From 57c8fa4e4d3b2f4e74a07400f7b775c6c456de6b Mon Sep 17 00:00:00 2001 From: mzz2017 <2017@duck.com> Date: Sat, 15 Jun 2024 14:36:11 +0800 Subject: [PATCH] fix: make udp packets be sent in order --- control/control_plane.go | 17 ++++++-- control/udp.go | 41 +++++++++++++++--- control/udp_endpoint_pool.go | 15 +++++-- control/udp_task_pool.go | 84 ++++++++++++++++++++++++++++++++++++ go.mod | 3 +- go.sum | 7 ++- 6 files changed, 153 insertions(+), 14 deletions(-) create mode 100644 control/udp_task_pool.go diff --git a/control/control_plane.go b/control/control_plane.go index 3430d5d..a2eaa7f 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -751,7 +751,15 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err copy(newBuf, buf[:n]) newOob := pool.Get(oobn) copy(newOob, oob[:oobn]) - go func(data pool.PB, oob pool.PB, src netip.AddrPort) { + newSrc := src + convergeSrc := common.ConvergeAddrPort(src) + // Debug: + // t := time.Now() + DefaultUdpTaskPool.EmitTask(convergeSrc.String(), func() { + data := newBuf + oob := newOob + src := newSrc + defer data.Put() defer oob.Put() var realDst netip.AddrPort @@ -764,10 +772,13 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err } else { realDst = pktDst } - if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil { + if e := c.handlePkt(udpConn, data, convergeSrc, common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil { c.log.Warnln("handlePkt:", e) } - }(newBuf, newOob, src) + }) + // if d := time.Since(t); d > 100*time.Millisecond { + // logrus.Println(d) + // } } }() c.ActivateCheck() diff --git a/control/udp.go b/control/udp.go index 708ac78..b1483ea 100644 --- a/control/udp.go +++ b/control/udp.go @@ -31,10 +31,11 @@ const ( ) type DialOption struct { - Target string - Dialer *dialer.Dialer - Outbound *ob.DialerGroup - Network string + Target string + Dialer *dialer.Dialer + Outbound *ob.DialerGroup + Network string + SniffedDomain string } func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) { @@ -78,12 +79,41 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r var realSrc netip.AddrPort var domain string realSrc = src + ue, ueExists := DefaultUdpEndpointPool.Get(realSrc) + if ueExists && ue.SniffedDomain != "" { + // It is quic ... + // Fast path. + domain := ue.SniffedDomain + dialTarget := realDst.String() + + if c.log.IsLevelEnabled(logrus.TraceLevel) { + fields := logrus.Fields{ + "network": "udp(fp)", + "outbound": ue.Outbound.Name, + "policy": ue.Outbound.GetSelectionPolicy(), + "dialer": ue.Dialer.Property().Name, + "sniffed": domain, + "ip": RefineAddrPortToShow(realDst), + "pid": routingResult.Pid, + "dscp": routingResult.Dscp, + "pname": ProcessName2String(routingResult.Pname[:]), + "mac": Mac2String(routingResult.Mac[:]), + } + c.log.WithFields(fields).Tracef("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget) + } + + _, err = ue.WriteTo(data, dialTarget) + if err != nil { + return err + } + return nil + } // To keep consistency with kernel program, we only sniff DNS request sent to 53. dnsMessage, natTimeout := ChooseNatTimeout(data, realDst.Port() == 53) // We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time. isDns := dnsMessage != nil - if !isDns && !skipSniffing && !DefaultUdpEndpointPool.Exists(realSrc) { + if !isDns && !skipSniffing && !ueExists { // Sniff Quic, ... key := PacketSnifferKey{ LAddr: realSrc, @@ -152,7 +182,6 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r // However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine. // Get udp endpoint. - var ue *UdpEndpoint retry := 0 networkType := &dialer.NetworkType{ L4Proto: consts.L4ProtoStr_UDP, diff --git a/control/udp_endpoint_pool.go b/control/udp_endpoint_pool.go index 45ce35c..1bb29e0 100644 --- a/control/udp_endpoint_pool.go +++ b/control/udp_endpoint_pool.go @@ -31,6 +31,10 @@ type UdpEndpoint struct { Dialer *dialer.Dialer Outbound *outbound.DialerGroup + + // Non-empty indicates this UDP Endpoint is related with a sniffed domain. + SniffedDomain string + DialTarget string } func (ue *UdpEndpoint) start() { @@ -95,9 +99,12 @@ func (p *UdpEndpointPool) Remove(lAddr netip.AddrPort, udpEndpoint *UdpEndpoint) return nil } -func (p *UdpEndpointPool) Exists(lAddr netip.AddrPort) (ok bool) { - _, ok = p.pool.Load(lAddr) - return ok +func (p *UdpEndpointPool) Get(lAddr netip.AddrPort) (udpEndpoint *UdpEndpoint, ok bool) { + _ue, ok := p.pool.Load(lAddr) + if !ok { + return nil, ok + } + return _ue.(*UdpEndpoint), ok } func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEndpointOptions) (udpEndpoint *UdpEndpoint, isNew bool, err error) { @@ -146,6 +153,8 @@ begin: NatTimeout: createOption.NatTimeout, Dialer: dialOption.Dialer, Outbound: dialOption.Outbound, + SniffedDomain: dialOption.SniffedDomain, + DialTarget: dialOption.Target, } ue.deadlineTimer = time.AfterFunc(createOption.NatTimeout, func() { if _ue, ok := p.pool.LoadAndDelete(lAddr); ok { diff --git a/control/udp_task_pool.go b/control/udp_task_pool.go new file mode 100644 index 0000000..0c8a78a --- /dev/null +++ b/control/udp_task_pool.go @@ -0,0 +1,84 @@ +package control + +import ( + "sync" + "time" +) + +const UdpTaskQueueLength = 128 + +type UdpTask = func() + +type UdpTaskQueue struct { + ch chan UdpTask + timer *time.Timer + agingTime time.Duration + closed chan struct{} + freed chan struct{} +} + +func (q *UdpTaskQueue) Push(task UdpTask) { + q.timer.Reset(q.agingTime) + q.ch <- task +} + +type UdpTaskPool struct { + queueChPool sync.Pool + // mu protects m + mu sync.Mutex + m map[string]*UdpTaskQueue +} + +func NewUdpTaskPool() *UdpTaskPool { + p := &UdpTaskPool{ + queueChPool: sync.Pool{New: func() any { + return make(chan UdpTask, UdpTaskQueueLength) + }}, + mu: sync.Mutex{}, + m: map[string]*UdpTaskQueue{}, + } + return p +} + +func (p *UdpTaskPool) convoy(q *UdpTaskQueue) { + for { + select { + case <-q.closed: + close(q.freed) + return + case t := <-q.ch: + t() + } + } +} + +func (p *UdpTaskPool) EmitTask(key string, task UdpTask) { + p.mu.Lock() + q, ok := p.m[key] + if !ok { + ch := p.queueChPool.Get().(chan UdpTask) + q = &UdpTaskQueue{ + ch: ch, + timer: nil, + agingTime: DefaultNatTimeout, + closed: make(chan struct{}), + freed: make(chan struct{}), + } + q.timer = time.AfterFunc(q.agingTime, func() { + p.mu.Lock() + defer p.mu.Unlock() + if p.m[key] == q { + delete(p.m, key) + } + close(q.closed) + <-q.freed + p.queueChPool.Put(ch) + }) + p.m[key] = q + go p.convoy(q) + } + p.mu.Unlock() + q.Push(task) +} + +var DefaultUdpTaskPool = NewUdpTaskPool() diff --git a/go.mod b/go.mod index 638e94d..3a9738f 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/gorilla/websocket v1.5.0 // indirect github.com/klauspost/compress v1.17.4 // indirect github.com/onsi/ginkgo/v2 v2.11.0 // indirect - github.com/quic-go/qpack v0.4.0 // indirect + github.com/stretchr/testify v1.8.1 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.20.0 // indirect @@ -64,6 +64,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mzz2017/disk-bloom v1.0.1 // indirect github.com/onsi/ginkgo v1.16.5 // indirect + github.com/quic-go/qpack v0.4.0 // indirect github.com/refraction-networking/utls v1.6.4 // indirect github.com/seiflotfy/cuckoofilter v0.0.0-20220411075957-e3b120b3f5fb // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/go.sum b/go.sum index f4dc039..56bc125 100644 --- a/go.sum +++ b/go.sum @@ -133,12 +133,17 @@ github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRM github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/v2rayA/ahocorasick-domain v0.0.0-20231231085011-99ceb8ef3208 h1:s/K1ome/+rTDictkqGhqLuAleUymyWnvgNWARjblS9U=