diff --git a/component/sniffing/conn_sniffer.go b/component/sniffing/conn_sniffer.go index 8cf100f..0ea4c59 100644 --- a/component/sniffing/conn_sniffer.go +++ b/component/sniffing/conn_sniffer.go @@ -9,48 +9,28 @@ import ( "errors" "net" "strings" - "sync" - "time" ) type ConnSniffer struct { net.Conn - sniffer *Sniffer - - mu sync.Mutex + *Sniffer } func NewConnSniffer(conn net.Conn, snifferBufSize int) *ConnSniffer { s := &ConnSniffer{ Conn: conn, - sniffer: NewStreamSniffer(conn, snifferBufSize), + Sniffer: NewStreamSniffer(conn, snifferBufSize), } return s } -func (s *ConnSniffer) SniffTcp() (d string, err error) { - s.Conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - defer s.Conn.SetReadDeadline(time.Time{}) - d, err = s.sniffer.SniffTcp() - if err != nil { - var netError net.Error - if errors.As(err, &netError) && netError.Timeout() { - return "", NotApplicableError - } - return "", err - } - return d, nil -} func (s *ConnSniffer) Read(p []byte) (n int, err error) { - s.mu.Lock() - n, err = s.sniffer.Read(p) - s.mu.Unlock() - return n, err + return s.Sniffer.Read(p) } func (s *ConnSniffer) Close() (err error) { var errs []string - if err = s.sniffer.Close(); err != nil { + if err = s.Sniffer.Close(); err != nil { errs = append(errs, err.Error()) } if err = s.Conn.Close(); err != nil { diff --git a/component/sniffing/sniffer.go b/component/sniffing/sniffer.go index 4cfebda..94d918a 100644 --- a/component/sniffing/sniffer.go +++ b/component/sniffing/sniffer.go @@ -6,9 +6,11 @@ package sniffing import ( - "github.com/mzz2017/softwind/pool" + "errors" "io" + "net" "sync" + "time" ) type Sniffer struct { @@ -42,12 +44,26 @@ func (s *Sniffer) SniffTcp() (d string, err error) { s.readMu.Lock() defer s.readMu.Unlock() if s.stream { - n, err := s.r.Read(s.buf) + r, isConn := s.r.(net.Conn) + if isConn { + // Set timeout. + r.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + } + n, err := r.Read(s.buf) + if isConn { + // Recover. + r.SetReadDeadline(time.Time{}) + } s.buf = s.buf[:n] if err != nil { + var netError net.Error + if isConn && errors.As(err, &netError) && netError.Timeout() { + goto sniff + } return "", err } } +sniff: if len(s.buf) == 0 { return "", NotApplicableError } @@ -76,9 +92,6 @@ func (s *Sniffer) Read(p []byte) (n int, err error) { n = copy(p, s.buf[s.bufAt:]) s.bufAt += n if s.bufAt >= len(s.buf) { - if s.stream { - pool.Put(s.buf) - } s.buf = nil } return n, nil