diff --git a/control/control_plane.go b/control/control_plane.go index 27b81cc..7f0dfa1 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -764,7 +764,15 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err copy(newBuf, buf[:n]) newOob := pool.Get(oobn) copy(newOob, oob[:oobn]) - go func(data pool.PB, oob pool.PB, src netip.AddrPort) { + newSrc := src + convergeSrc := common.ConvergeAddrPort(src) + // Debug: + // t := time.Now() + DefaultUdpTaskPool.EmitTask(convergeSrc.String(), func() { + data := newBuf + oob := newOob + src := newSrc + defer data.Put() defer oob.Put() var realDst netip.AddrPort @@ -777,10 +785,13 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err } else { realDst = pktDst } - if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil { + if e := c.handlePkt(udpConn, data, convergeSrc, common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil { c.log.Warnln("handlePkt:", e) } - }(newBuf, newOob, src) + }) + // if d := time.Since(t); d > 100*time.Millisecond { + // logrus.Println(d) + // } } }() c.ActivateCheck() diff --git a/control/kern/tproxy.c b/control/kern/tproxy.c index 8559f33..11d9b04 100644 --- a/control/kern/tproxy.c +++ b/control/kern/tproxy.c @@ -1290,24 +1290,19 @@ refresh_udp_conn_state_timer(struct tuples_key *key, bool is_egress) if (unlikely(!value)) return NULL; - ret = bpf_timer_init(&value->timer, &udp_conn_state_map, - CLOCK_MONOTONIC); - if (unlikely(ret)) - goto del; + if ((ret = bpf_timer_init(&value->timer, &udp_conn_state_map, + CLOCK_MONOTONIC))) + goto retn; - ret = bpf_timer_set_callback(&value->timer, - refresh_udp_conn_state_timer_cb); - if (unlikely(ret)) - goto del; + if ((ret = bpf_timer_set_callback(&value->timer, + refresh_udp_conn_state_timer_cb))) + goto retn; - ret = bpf_timer_start(&value->timer, TIMEOUT_UDP_CONN_STATE, 0); - if (unlikely(ret)) - goto del; + if ((ret = bpf_timer_start(&value->timer, TIMEOUT_UDP_CONN_STATE, 0))) + goto retn; +retn: return value; -del: - bpf_map_delete_elem(&udp_conn_state_map, key); - return NULL; } SEC("tc/wan_ingress") @@ -1515,17 +1510,22 @@ int tproxy_wan_egress(struct __sk_buff *skb) flag[6] = tuples.dscp; struct pid_pname *pid_pname; - struct udp_conn_state *conn_state = - refresh_udp_conn_state_timer(&tuples.five, true); - if (!conn_state) - return TC_ACT_SHOT; - if (!conn_state->is_egress || - pid_is_control_plane(skb, &pid_pname)) { - // Input udp connection or + if (pid_is_control_plane(skb, &pid_pname)) { // from control plane // => direct. return TC_ACT_OK; } + + struct udp_conn_state *conn_state = + refresh_udp_conn_state_timer(&tuples.five, true); + if (!conn_state) + return TC_ACT_SHOT; + if (!conn_state->is_egress) { + // Input udp connection + // => direct. + return TC_ACT_OK; + } + if (pid_pname) { // 2, 3, 4, 5 __builtin_memcpy(&flag[2], pid_pname->pname, diff --git a/control/packet_sniffer_pool.go b/control/packet_sniffer_pool.go index 56319f2..3d159c1 100644 --- a/control/packet_sniffer_pool.go +++ b/control/packet_sniffer_pool.go @@ -37,7 +37,7 @@ type PacketSnifferKey struct { RAddr netip.AddrPort } -var DefaultPacketSnifferPool = NewPacketSnifferPool() +var DefaultPacketSnifferSessionMgr = NewPacketSnifferPool() func NewPacketSnifferPool() *PacketSnifferPool { return &PacketSnifferPool{} diff --git a/control/packet_sniffer_pool_test.go b/control/packet_sniffer_pool_test.go index 55300b5..30b6841 100644 --- a/control/packet_sniffer_pool_test.go +++ b/control/packet_sniffer_pool_test.go @@ -21,7 +21,7 @@ var testPacketSnifferData = []string{ func TestPacketSniffer_Normal(t *testing.T) { for _, _data := range testPacketSnifferData { data, _ := hex.DecodeString(_data) - sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(PacketSnifferKey{ + sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(PacketSnifferKey{ LAddr: netip.MustParseAddrPort("1.1.1.1:1111"), RAddr: netip.MustParseAddrPort("2.2.2.2:2222"), }, nil) @@ -44,7 +44,7 @@ func TestPacketSniffer_Mismatched(t *testing.T) { dst := netip.MustParseAddrPort("2.2.2.2:2222") for _, _data := range testPacketSnifferData { data, _ := hex.DecodeString(_data) - sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(PacketSnifferKey{ + sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(PacketSnifferKey{ LAddr: netip.MustParseAddrPort("1.1.1.1:1111"), RAddr: dst, }, nil) diff --git a/control/udp.go b/control/udp.go index b96b5f4..638ce93 100644 --- a/control/udp.go +++ b/control/udp.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/netip" + "time" "github.com/daeuniverse/dae/common" @@ -29,10 +30,11 @@ const ( ) type DialOption struct { - Target string - Dialer *dialer.Dialer - Outbound *ob.DialerGroup - Network string + Target string + Dialer *dialer.Dialer + Outbound *ob.DialerGroup + Network string + SniffedDomain string } func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) { @@ -60,21 +62,50 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r var realSrc netip.AddrPort var domain string realSrc = src + ue, ueExists := DefaultUdpEndpointPool.Get(realSrc) + if ueExists && ue.SniffedDomain != "" { + // It is quic ... + // Fast path. + domain := ue.SniffedDomain + dialTarget := realDst.String() + + if c.log.IsLevelEnabled(logrus.TraceLevel) { + fields := logrus.Fields{ + "network": "udp(fp)", + "outbound": ue.Outbound.Name, + "policy": ue.Outbound.GetSelectionPolicy(), + "dialer": ue.Dialer.Property().Name, + "sniffed": domain, + "ip": RefineAddrPortToShow(realDst), + "pid": routingResult.Pid, + "dscp": routingResult.Dscp, + "pname": ProcessName2String(routingResult.Pname[:]), + "mac": Mac2String(routingResult.Mac[:]), + } + c.log.WithFields(fields).Tracef("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget) + } + + _, err = ue.WriteTo(data, dialTarget) + if err != nil { + return err + } + return nil + } // To keep consistency with kernel program, we only sniff DNS request sent to 53. dnsMessage, natTimeout := ChooseNatTimeout(data, realDst.Port() == 53) // We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time. isDns := dnsMessage != nil - if !isDns && !skipSniffing && !DefaultUdpEndpointPool.Exists(realSrc) { + if !isDns && !skipSniffing && !ueExists { // Sniff Quic, ... key := PacketSnifferKey{ LAddr: realSrc, RAddr: realDst, } - _sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(key, nil) + _sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(key, nil) _sniffer.Mu.Lock() // Re-get sniffer from pool to confirm the transaction is not done. - sniffer := DefaultPacketSnifferPool.Get(key) + sniffer := DefaultPacketSnifferSessionMgr.Get(key) if _sniffer == sniffer { sniffer.AppendData(data) domain, err = sniffer.SniffUdp() @@ -92,7 +123,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r WithField("to", realDst). Trace("sniffUdp") } - defer DefaultPacketSnifferPool.Remove(key, sniffer) + defer DefaultPacketSnifferSessionMgr.Remove(key, sniffer) // Re-handlePkt after self func. toRehandle := sniffer.Data()[1 : len(sniffer.Data())-1] // Skip the first empty and the last (self). sniffer.Mu.Unlock() @@ -134,7 +165,6 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r // However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine. // Get udp endpoint. - var ue *UdpEndpoint retry := 0 networkType := &dialer.NetworkType{ L4Proto: consts.L4ProtoStr_UDP, @@ -217,10 +247,11 @@ getNew: return nil, fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err) } return &DialOption{ - Target: dialTarget, - Dialer: dialerForNew, - Outbound: outbound, - Network: common.MagicNetwork("udp", routingResult.Mark), + Target: dialTarget, + Dialer: dialerForNew, + Outbound: outbound, + Network: common.MagicNetwork("udp", routingResult.Mark), + SniffedDomain: domain, }, nil }, }) @@ -243,6 +274,10 @@ getNew: retry++ goto getNew } + if domain == "" { + // It is used for showing. + domain = ue.SniffedDomain + } _, err = ue.WriteTo(data, dialTarget) if err != nil { @@ -280,7 +315,11 @@ getNew: "pname": ProcessName2String(routingResult.Pname[:]), "mac": Mac2String(routingResult.Mac[:]), } - c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget) + logger := c.log.WithFields(fields).Infof + if !isNew && c.log.IsLevelEnabled(logrus.DebugLevel) { + logger = c.log.WithFields(fields).Debugf + } + logger("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget) } return nil diff --git a/control/udp_endpoint_pool.go b/control/udp_endpoint_pool.go index 45ce35c..1bb29e0 100644 --- a/control/udp_endpoint_pool.go +++ b/control/udp_endpoint_pool.go @@ -31,6 +31,10 @@ type UdpEndpoint struct { Dialer *dialer.Dialer Outbound *outbound.DialerGroup + + // Non-empty indicates this UDP Endpoint is related with a sniffed domain. + SniffedDomain string + DialTarget string } func (ue *UdpEndpoint) start() { @@ -95,9 +99,12 @@ func (p *UdpEndpointPool) Remove(lAddr netip.AddrPort, udpEndpoint *UdpEndpoint) return nil } -func (p *UdpEndpointPool) Exists(lAddr netip.AddrPort) (ok bool) { - _, ok = p.pool.Load(lAddr) - return ok +func (p *UdpEndpointPool) Get(lAddr netip.AddrPort) (udpEndpoint *UdpEndpoint, ok bool) { + _ue, ok := p.pool.Load(lAddr) + if !ok { + return nil, ok + } + return _ue.(*UdpEndpoint), ok } func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEndpointOptions) (udpEndpoint *UdpEndpoint, isNew bool, err error) { @@ -146,6 +153,8 @@ begin: NatTimeout: createOption.NatTimeout, Dialer: dialOption.Dialer, Outbound: dialOption.Outbound, + SniffedDomain: dialOption.SniffedDomain, + DialTarget: dialOption.Target, } ue.deadlineTimer = time.AfterFunc(createOption.NatTimeout, func() { if _ue, ok := p.pool.LoadAndDelete(lAddr); ok { diff --git a/control/udp_task_pool.go b/control/udp_task_pool.go new file mode 100644 index 0000000..f59a5b2 --- /dev/null +++ b/control/udp_task_pool.go @@ -0,0 +1,92 @@ +package control + +import ( + "sync" + "time" +) + +const UdpTaskQueueLength = 128 + +type UdpTask = func() + +type UdpTaskQueue struct { + ch chan UdpTask + timer *time.Timer + agingTime time.Duration + closed chan struct{} + freed chan struct{} +} + +func (q *UdpTaskQueue) Push(task UdpTask) { + q.timer.Reset(q.agingTime) + q.ch <- task +} + +type UdpTaskPool struct { + queueChPool sync.Pool + // mu protects m + mu sync.Mutex + m map[string]*UdpTaskQueue +} + +func NewUdpTaskPool() *UdpTaskPool { + p := &UdpTaskPool{ + queueChPool: sync.Pool{New: func() any { + return make(chan UdpTask, UdpTaskQueueLength) + }}, + mu: sync.Mutex{}, + m: map[string]*UdpTaskQueue{}, + } + return p +} + +func (p *UdpTaskPool) convoy(q *UdpTaskQueue) { + for { + select { + case <-q.closed: + clearloop: + for { + select { + case <-q.ch: + default: + break clearloop + } + } + close(q.freed) + return + case t := <-q.ch: + t() + } + } +} + +func (p *UdpTaskPool) EmitTask(key string, task UdpTask) { + p.mu.Lock() + q, ok := p.m[key] + if !ok { + ch := p.queueChPool.Get().(chan UdpTask) + q = &UdpTaskQueue{ + ch: ch, + timer: nil, + agingTime: DefaultNatTimeout, + closed: make(chan struct{}), + freed: make(chan struct{}), + } + q.timer = time.AfterFunc(q.agingTime, func() { + p.mu.Lock() + defer p.mu.Unlock() + if p.m[key] == q { + delete(p.m, key) + } + close(q.closed) + <-q.freed + p.queueChPool.Put(ch) + }) + p.m[key] = q + go p.convoy(q) + } + p.mu.Unlock() + q.Push(task) +} + +var DefaultUdpTaskPool = NewUdpTaskPool() diff --git a/go.mod b/go.mod index bf3229b..8dafc82 100644 --- a/go.mod +++ b/go.mod @@ -39,6 +39,7 @@ require ( github.com/gorilla/websocket v1.5.0 // indirect github.com/klauspost/compress v1.17.4 // indirect github.com/onsi/ginkgo/v2 v2.11.0 // indirect + github.com/stretchr/testify v1.8.1 // indirect go.uber.org/mock v0.4.0 // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.20.0 // indirect diff --git a/go.sum b/go.sum index d99b7da..2e28d3c 100644 --- a/go.sum +++ b/go.sum @@ -133,12 +133,17 @@ github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRM github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= github.com/v2rayA/ahocorasick-domain v0.0.0-20231231085011-99ceb8ef3208 h1:s/K1ome/+rTDictkqGhqLuAleUymyWnvgNWARjblS9U=