From 4efed17ff9f0905a0e2013d9716766c2ab396af4 Mon Sep 17 00:00:00 2001 From: mzz2017 <2017@duck.com> Date: Wed, 12 Apr 2023 21:45:50 +0800 Subject: [PATCH] fix: use goroutine instead of SetReadDeadline to try to fix compatibility --- component/sniffing/sniffer.go | 109 +++++++++++++++++++++------------- control/udp.go | 4 +- 2 files changed, 71 insertions(+), 42 deletions(-) diff --git a/component/sniffing/sniffer.go b/component/sniffing/sniffer.go index 94d918a..0d4849d 100644 --- a/component/sniffing/sniffer.go +++ b/component/sniffing/sniffer.go @@ -6,26 +6,30 @@ package sniffing import ( - "errors" "io" - "net" "sync" "time" ) type Sniffer struct { - r io.Reader + // Stream + stream bool + r io.Reader + dataReady chan struct{} + dataError error + + // Common buf []byte bufAt int - stream bool readMu sync.Mutex } func NewStreamSniffer(r io.Reader, bufSize int) *Sniffer { s := &Sniffer{ - r: r, - buf: make([]byte, bufSize), - stream: true, + r: r, + buf: make([]byte, bufSize), + stream: true, + dataReady: make(chan struct{}), } return s } @@ -40,38 +44,7 @@ func NewPacketSniffer(data []byte) *Sniffer { type sniff func() (d string, err error) -func (s *Sniffer) SniffTcp() (d string, err error) { - s.readMu.Lock() - defer s.readMu.Unlock() - if s.stream { - r, isConn := s.r.(net.Conn) - if isConn { - // Set timeout. - r.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - } - n, err := r.Read(s.buf) - if isConn { - // Recover. - r.SetReadDeadline(time.Time{}) - } - s.buf = s.buf[:n] - if err != nil { - var netError net.Error - if isConn && errors.As(err, &netError) && netError.Timeout() { - goto sniff - } - return "", err - } - } -sniff: - if len(s.buf) == 0 { - return "", NotApplicableError - } - sniffs := []sniff{ - // Most sniffable traffic is TLS, thus we sniff it first. - s.SniffTls, - s.SniffHttp, - } +func sniffGroup(sniffs []sniff) (d string, err error) { for _, sniffer := range sniffs { d, err = sniffer() if err == nil { @@ -84,9 +57,65 @@ sniff: return "", NotApplicableError } -func (s *Sniffer) Read(p []byte) (n int, err error) { +func (s *Sniffer) SniffTcp() (d string, err error) { s.readMu.Lock() defer s.readMu.Unlock() + if s.stream { + go func() { + n, err := s.r.Read(s.buf) + s.buf = s.buf[:n] + if err != nil { + s.dataError = err + } + close(s.dataReady) + }() + + // Waiting 100ms for data. + select { + case <-time.After(100 * time.Millisecond): + return "", NotApplicableError + case <-s.dataReady: + if s.dataError != nil { + return "", s.dataError + } + } + } else { + close(s.dataReady) + } + + if len(s.buf) == 0 { + return "", NotApplicableError + } + + return sniffGroup([]sniff{ + // Most sniffable traffic is TLS, thus we sniff it first. + s.SniffTls, + s.SniffHttp, + }) +} + +func (s *Sniffer) SniffUdp() (d string, err error) { + s.readMu.Lock() + defer s.readMu.Unlock() + + if len(s.buf) == 0 { + return "", NotApplicableError + } + + return sniffGroup([]sniff{ + s.SniffQuic, + }) +} + +func (s *Sniffer) Read(p []byte) (n int, err error) { + <-s.dataReady + if s.dataError != nil { + return 0, s.dataError + } + + s.readMu.Lock() + defer s.readMu.Unlock() + if s.buf != nil && s.bufAt < len(s.buf) { // Read buf first. n = copy(p, s.buf[s.bufAt:]) diff --git a/control/udp.go b/control/udp.go index 6434804..b79387a 100644 --- a/control/udp.go +++ b/control/udp.go @@ -127,9 +127,9 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r // We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time. isDns := dnsMessage != nil if !isDns { - // Sniff Quic + // Sniff Quic, ... sniffer := sniffing.NewPacketSniffer(data) - domain, err = sniffer.SniffQuic() + domain, err = sniffer.SniffUdp() if err != nil && !sniffing.IsSniffingError(err) { sniffer.Close() return err