From acfc1db679df8c25d34dd58d688c5e23f26c0a22 Mon Sep 17 00:00:00 2001 From: mzz <2017@duck.com> Date: Thu, 13 Jul 2023 19:04:48 +0800 Subject: [PATCH] fix/optimize/refactor(udp): fix potential stuck UDP and optimize reroute logic (#204) --- common/consts/dialer.go | 2 + control/dns_control.go | 12 +++- control/tcp.go | 9 ++- control/udp.go | 149 +++++++++++++++++++++------------------- control/udp_endpoint.go | 70 +++++++++++-------- go.mod | 2 +- go.sum | 4 +- 7 files changed, 141 insertions(+), 107 deletions(-) diff --git a/common/consts/dialer.go b/common/consts/dialer.go index ce27943..7282871 100644 --- a/common/consts/dialer.go +++ b/common/consts/dialer.go @@ -7,6 +7,7 @@ package consts import ( "net/netip" + "time" "golang.org/x/sys/unix" ) @@ -23,6 +24,7 @@ const ( const ( UdpCheckLookupHost = "connectivitycheck.gstatic.com." + DefaultDialTimeout = 8 * time.Second ) type L4ProtoStr string diff --git a/control/dns_control.go b/control/dns_control.go index 31f7362..6c12927 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -591,12 +591,20 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte // We should set a connClosed flag to avoid it. var connClosed bool var conn netproxy.Conn + + ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) + defer cancel() + bestContextDialer := netproxy.ContextDialer{ + Dialer: dialArgument.bestDialer, + } + switch dialArgument.l4proto { case consts.L4ProtoStr_UDP: // Get udp endpoint. // TODO: connection pool. - conn, err = dialArgument.bestDialer.Dial( + conn, err = bestContextDialer.DialContext( + ctxDial, common.MagicNetwork("udp", dialArgument.mark), dialArgument.bestTarget.String(), ) @@ -659,7 +667,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte case consts.L4ProtoStr_TCP: // We can block here because we are in a coroutine. - conn, err = dialArgument.bestDialer.Dial(common.MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String()) + conn, err = bestContextDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String()) if err != nil { return fmt.Errorf("failed to dial proxy to tcp: %w", err) } diff --git a/control/tcp.go b/control/tcp.go index 1f0aeef..fb3e7b2 100644 --- a/control/tcp.go +++ b/control/tcp.go @@ -6,6 +6,7 @@ package control import ( + "context" "fmt" "net" "net/netip" @@ -185,8 +186,12 @@ func (c *ControlPlane) RouteDialTcp(p *RouteDialParam) (conn netproxy.Conn, err "mac": Mac2String(routingResult.Mac[:]), }).Infof("%v <-> %v", RefineSourceToShow(src, dst.Addr(), consts.LanWanFlag_NotApplicable), dialTarget) } - - return d.Dial(common.MagicNetwork("tcp", routingResult.Mark), dialTarget) + ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) + defer cancel() + cd := netproxy.ContextDialer{ + Dialer: d, + } + return cd.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark), dialTarget) } type WriteCloser interface { diff --git a/control/udp.go b/control/udp.go index eb07037..5419bb7 100644 --- a/control/udp.go +++ b/control/udp.go @@ -17,6 +17,7 @@ import ( "github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common/consts" + ob "github.com/daeuniverse/dae/component/outbound" "github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/sniffing" internal "github.com/daeuniverse/dae/pkg/ebpf_internal" @@ -31,6 +32,13 @@ const ( MaxRetry = 2 ) +type DialOption struct { + Target string + Dialer *dialer.Dialer + Outbound *ob.DialerGroup + Network string +} + func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) { if sniffDns { var dnsmsg dnsmessage.Msg @@ -139,43 +147,9 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r return err } } - - // Get outbound. - outboundIndex := consts.OutboundIndex(routingResult.Outbound) - if c.dialMode == consts.DialMode_DomainCao && domain != "" { - outboundIndex = consts.OutboundControlPlaneRouting - } - - dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, realDst, domain) - if shouldReroute { - outboundIndex = consts.OutboundControlPlaneRouting - } - if routingResult.Must > 0 { isDns = false // Regard as plain traffic. } - switch outboundIndex { - case consts.OutboundDirect: - case consts.OutboundControlPlaneRouting: - if isDns { - // Routing of DNS packets are managed by DNS controller. - break - } - - if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { - return err - } - routingResult.Outbound = uint8(outboundIndex) - if c.log.IsLevelEnabled(logrus.TraceLevel) { - c.log.Tracef("outbound: %v => %v", - consts.OutboundControlPlaneRouting.String(), - outboundIndex.String(), - ) - } - // Reset dialTarget. - dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, realDst, domain) - default: - } if routingResult.Mark == 0 { routingResult.Mark = c.soMarkFromDae } @@ -190,23 +164,6 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r }) } - if int(outboundIndex) >= len(c.outbounds) { - return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1) - } - outbound := c.outbounds[outboundIndex] - - // Select dialer from outbound (dialer group). - networkType := &dialer.NetworkType{ - L4Proto: consts.L4ProtoStr_UDP, - IpVersion: consts.IpVersionFromAddr(realDst.Addr()), - IsDns: true, // UDP relies on DNS check result. - } - strictIpVersion := dialIp - dialerForNew, _, err := outbound.Select(networkType, strictIpVersion) - if err != nil { - return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err) - } - // Dial and send. // TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr). // Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP. @@ -215,6 +172,17 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r // Get udp endpoint. var ue *UdpEndpoint retry := 0 + networkType := &dialer.NetworkType{ + L4Proto: consts.L4ProtoStr_UDP, + IpVersion: consts.IpVersionFromAddr(realDst.Addr()), + IsDns: true, // UDP relies on DNS check result. + } + // Get outbound. + outboundIndex := consts.OutboundIndex(routingResult.Outbound) + if c.dialMode == consts.DialMode_DomainCao && domain != "" { + outboundIndex = consts.OutboundControlPlaneRouting + } + dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, realDst, domain) getNew: if retry > MaxRetry { c.log.WithFields(logrus.Fields{ @@ -232,16 +200,59 @@ getNew: return sendPkt(data, from, realSrc, src, lConn, lanWanFlag) }, NatTimeout: natTimeout, - Dialer: dialerForNew, - Network: common.MagicNetwork("udp", routingResult.Mark), - Target: dialTarget, + GetDialOption: func() (option *DialOption, err error) { + if shouldReroute { + outboundIndex = consts.OutboundControlPlaneRouting + } + + switch outboundIndex { + case consts.OutboundDirect: + case consts.OutboundControlPlaneRouting: + if isDns { + // Routing of DNS packets are managed by DNS controller. + break + } + + if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil { + return nil, err + } + routingResult.Outbound = uint8(outboundIndex) + if c.log.IsLevelEnabled(logrus.TraceLevel) { + c.log.Tracef("outbound: %v => %v", + consts.OutboundControlPlaneRouting.String(), + outboundIndex.String(), + ) + } + // Reset dialTarget. + dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, realDst, domain) + default: + } + + if int(outboundIndex) >= len(c.outbounds) { + return nil, fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1) + } + outbound := c.outbounds[outboundIndex] + + // Select dialer from outbound (dialer group). + strictIpVersion := dialIp + dialerForNew, _, err := outbound.Select(networkType, strictIpVersion) + if err != nil { + return nil, fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err) + } + return &DialOption{ + Target: dialTarget, + Dialer: dialerForNew, + Outbound: outbound, + Network: common.MagicNetwork("udp", routingResult.Mark), + }, nil + }, }) if err != nil { - return fmt.Errorf("failed to GetOrCreate (policy: %v): %w", outbound.GetSelectionPolicy(), err) + return fmt.Errorf("failed to GetOrCreate: %w", err) } // If the udp endpoint has been not alive, remove it from pool and get a new one. - if !isNew && outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) { + if !isNew && ue.Outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) { if c.log.IsLevelEnabled(logrus.DebugLevel) { c.log.WithFields(logrus.Fields{ @@ -278,21 +289,19 @@ getNew: // Print log. // Only print routing for new connection to avoid the log exploded (Quic and BT). - if isNew { - if c.log.IsLevelEnabled(logrus.InfoLevel) { - fields := logrus.Fields{ - "network": networkType.StringWithoutDns(), - "outbound": outbound.Name, - "policy": outbound.GetSelectionPolicy(), - "dialer": ue.Dialer.Property().Name, - "domain": domain, - "ip": RefineAddrPortToShow(realDst), - "pid": routingResult.Pid, - "pname": ProcessName2String(routingResult.Pname[:]), - "mac": Mac2String(routingResult.Mac[:]), - } - c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), dialTarget) + if isNew && c.log.IsLevelEnabled(logrus.InfoLevel) || c.log.IsLevelEnabled(logrus.DebugLevel) { + fields := logrus.Fields{ + "network": networkType.StringWithoutDns(), + "outbound": ue.Outbound.Name, + "policy": ue.Outbound.GetSelectionPolicy(), + "dialer": ue.Dialer.Property().Name, + "domain": domain, + "ip": RefineAddrPortToShow(realDst), + "pid": routingResult.Pid, + "pname": ProcessName2String(routingResult.Pname[:]), + "mac": Mac2String(routingResult.Mac[:]), } + c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), dialTarget) } return nil diff --git a/control/udp_endpoint.go b/control/udp_endpoint.go index 90452ca..82e07da 100644 --- a/control/udp_endpoint.go +++ b/control/udp_endpoint.go @@ -6,12 +6,14 @@ package control import ( + "context" "fmt" "net/netip" "sync" "time" "github.com/daeuniverse/dae/common/consts" + "github.com/daeuniverse/dae/component/outbound" "github.com/daeuniverse/dae/component/outbound/dialer" "github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/pool" @@ -27,7 +29,8 @@ type UdpEndpoint struct { handler UdpHandler NatTimeout time.Duration - Dialer *dialer.Dialer + Dialer *dialer.Dialer + Outbound *outbound.DialerGroup } func (ue *UdpEndpoint) start() { @@ -65,46 +68,44 @@ func (ue *UdpEndpoint) Close() error { // UdpEndpointPool is a full-cone udp conn pool type UdpEndpointPool struct { - pool map[netip.AddrPort]*UdpEndpoint - mu sync.Mutex + pool sync.Map + createMuMap sync.Map } type UdpEndpointOptions struct { Handler UdpHandler NatTimeout time.Duration - Dialer *dialer.Dialer - // Network is useful for MagicNetwork - Network string - // Target is useful only if the underlay does not support Full-cone. - Target string + // GetTarget is useful only if the underlay does not support Full-cone. + GetDialOption func() (option *DialOption, err error) } var DefaultUdpEndpointPool = NewUdpEndpointPool() func NewUdpEndpointPool() *UdpEndpointPool { - return &UdpEndpointPool{ - pool: make(map[netip.AddrPort]*UdpEndpoint), - } + return &UdpEndpointPool{} } func (p *UdpEndpointPool) Remove(lAddr netip.AddrPort, udpEndpoint *UdpEndpoint) (err error) { - p.mu.Lock() - defer p.mu.Unlock() - if ue, ok := p.pool[lAddr]; ok { + if ue, ok := p.pool.LoadAndDelete(lAddr); ok { if ue != udpEndpoint { return fmt.Errorf("target udp endpoint is not in the pool") } - ue.Close() - delete(p.pool, lAddr) + ue.(*UdpEndpoint).Close() } return nil } func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEndpointOptions) (udpEndpoint *UdpEndpoint, isNew bool, err error) { - // TODO: fine-grained lock. - p.mu.Lock() - defer p.mu.Unlock() - ue, ok := p.pool[lAddr] + _ue, ok := p.pool.Load(lAddr) +begin: if !ok { + createMu, _ := p.createMuMap.LoadOrStore(lAddr, &sync.Mutex{}) + createMu.(*sync.Mutex).Lock() + defer createMu.(*sync.Mutex).Unlock() + defer p.createMuMap.Delete(lAddr) + _ue, ok = p.pool.Load(lAddr) + if ok { + goto begin + } // Create an UdpEndpoint. if createOption == nil { createOption = &UdpEndpointOptions{} @@ -116,36 +117,45 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd return nil, true, fmt.Errorf("createOption.Handler cannot be nil") } - udpConn, err := createOption.Dialer.Dial(createOption.Network, createOption.Target) + dialOption, err := createOption.GetDialOption() + if err != nil { + return nil, false, err + } + cd := netproxy.ContextDialer{ + Dialer: dialOption.Dialer, + } + ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) + defer cancel() + udpConn, err := cd.DialContext(ctx, dialOption.Network, dialOption.Target) if err != nil { return nil, true, err } if _, ok = udpConn.(netproxy.PacketConn); !ok { return nil, true, fmt.Errorf("protocol does not support udp") } - ue = &UdpEndpoint{ + ue := &UdpEndpoint{ conn: udpConn.(netproxy.PacketConn), deadlineTimer: time.AfterFunc(createOption.NatTimeout, func() { - p.mu.Lock() - defer p.mu.Unlock() - if ue, ok := p.pool[lAddr]; ok { - ue.Close() - delete(p.pool, lAddr) + if ue, ok := p.pool.LoadAndDelete(lAddr); ok { + ue.(*UdpEndpoint).Close() } }), handler: createOption.Handler, NatTimeout: createOption.NatTimeout, - Dialer: createOption.Dialer, + Dialer: dialOption.Dialer, + Outbound: dialOption.Outbound, } - p.pool[lAddr] = ue + _ue = ue + p.pool.Store(lAddr, ue) // Receive UDP messages. go ue.start() isNew = true } else { + ue := _ue.(*UdpEndpoint) // Postpone the deadline. ue.mu.Lock() ue.deadlineTimer.Reset(ue.NatTimeout) ue.mu.Unlock() } - return ue, isNew, nil + return _ue.(*UdpEndpoint), isNew, nil } diff --git a/go.mod b/go.mod index cda013e..fbf4fcb 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/json-iterator/go v1.1.12 github.com/miekg/dns v1.1.55 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 - github.com/mzz2017/softwind v0.0.0-20230710142544-73a557cea4a4 + github.com/mzz2017/softwind v0.0.0-20230710175107-0107af8a1d26 github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd github.com/safchain/ethtool v0.3.0 github.com/sirupsen/logrus v1.9.3 diff --git a/go.sum b/go.sum index bc30b0d..7d8bf24 100644 --- a/go.sum +++ b/go.sum @@ -91,8 +91,8 @@ github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI= github.com/mzz2017/quic-go v0.0.0-20230706143320-cc858d4932b7 h1:9zmZilN02x3byMB2X3x+B4iyKHkucv70WA4hsyZkjo8= github.com/mzz2017/quic-go v0.0.0-20230706143320-cc858d4932b7/go.mod h1:3H6d55CEofIWWr3gQThiB27+hA3WG5tATtPovzEYPAA= -github.com/mzz2017/softwind v0.0.0-20230710142544-73a557cea4a4 h1:U6oSJf+dwVXpBZGi73l77igid+sOy4jgJucjSrfowFU= -github.com/mzz2017/softwind v0.0.0-20230710142544-73a557cea4a4/go.mod h1:Fz8fgR7/dbnfR6RLpeOMkUDyebq4xShdmjj+cE5jnJ4= +github.com/mzz2017/softwind v0.0.0-20230710175107-0107af8a1d26 h1:kVjALMAhr+rYw77TfrpD8VNIRbZ2/2pN1AYWBcL6eqM= +github.com/mzz2017/softwind v0.0.0-20230710175107-0107af8a1d26/go.mod h1:Fz8fgR7/dbnfR6RLpeOMkUDyebq4xShdmjj+cE5jnJ4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=