fix/optimize/refactor(udp): fix potential stuck UDP and optimize reroute logic (#204)

This commit is contained in:
mzz 2023-07-13 19:04:48 +08:00 committed by GitHub
parent ceab2edd00
commit acfc1db679
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 107 deletions

View File

@ -7,6 +7,7 @@ package consts
import ( import (
"net/netip" "net/netip"
"time"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -23,6 +24,7 @@ const (
const ( const (
UdpCheckLookupHost = "connectivitycheck.gstatic.com." UdpCheckLookupHost = "connectivitycheck.gstatic.com."
DefaultDialTimeout = 8 * time.Second
) )
type L4ProtoStr string type L4ProtoStr string

View File

@ -591,12 +591,20 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
// We should set a connClosed flag to avoid it. // We should set a connClosed flag to avoid it.
var connClosed bool var connClosed bool
var conn netproxy.Conn var conn netproxy.Conn
ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
bestContextDialer := netproxy.ContextDialer{
Dialer: dialArgument.bestDialer,
}
switch dialArgument.l4proto { switch dialArgument.l4proto {
case consts.L4ProtoStr_UDP: case consts.L4ProtoStr_UDP:
// Get udp endpoint. // Get udp endpoint.
// TODO: connection pool. // TODO: connection pool.
conn, err = dialArgument.bestDialer.Dial( conn, err = bestContextDialer.DialContext(
ctxDial,
common.MagicNetwork("udp", dialArgument.mark), common.MagicNetwork("udp", dialArgument.mark),
dialArgument.bestTarget.String(), dialArgument.bestTarget.String(),
) )
@ -659,7 +667,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
case consts.L4ProtoStr_TCP: case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine. // We can block here because we are in a coroutine.
conn, err = dialArgument.bestDialer.Dial(common.MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String()) conn, err = bestContextDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String())
if err != nil { if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err) return fmt.Errorf("failed to dial proxy to tcp: %w", err)
} }

View File

@ -6,6 +6,7 @@
package control package control
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -185,8 +186,12 @@ func (c *ControlPlane) RouteDialTcp(p *RouteDialParam) (conn netproxy.Conn, err
"mac": Mac2String(routingResult.Mac[:]), "mac": Mac2String(routingResult.Mac[:]),
}).Infof("%v <-> %v", RefineSourceToShow(src, dst.Addr(), consts.LanWanFlag_NotApplicable), dialTarget) }).Infof("%v <-> %v", RefineSourceToShow(src, dst.Addr(), consts.LanWanFlag_NotApplicable), dialTarget)
} }
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
return d.Dial(common.MagicNetwork("tcp", routingResult.Mark), dialTarget) defer cancel()
cd := netproxy.ContextDialer{
Dialer: d,
}
return cd.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark), dialTarget)
} }
type WriteCloser interface { type WriteCloser interface {

View File

@ -17,6 +17,7 @@ import (
"github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
ob "github.com/daeuniverse/dae/component/outbound"
"github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/daeuniverse/dae/component/sniffing" "github.com/daeuniverse/dae/component/sniffing"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal" internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
@ -31,6 +32,13 @@ const (
MaxRetry = 2 MaxRetry = 2
) )
type DialOption struct {
Target string
Dialer *dialer.Dialer
Outbound *ob.DialerGroup
Network string
}
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) { func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) {
if sniffDns { if sniffDns {
var dnsmsg dnsmessage.Msg var dnsmsg dnsmessage.Msg
@ -139,43 +147,9 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
return err return err
} }
} }
// Get outbound.
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
if c.dialMode == consts.DialMode_DomainCao && domain != "" {
outboundIndex = consts.OutboundControlPlaneRouting
}
dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, realDst, domain)
if shouldReroute {
outboundIndex = consts.OutboundControlPlaneRouting
}
if routingResult.Must > 0 { if routingResult.Must > 0 {
isDns = false // Regard as plain traffic. isDns = false // Regard as plain traffic.
} }
switch outboundIndex {
case consts.OutboundDirect:
case consts.OutboundControlPlaneRouting:
if isDns {
// Routing of DNS packets are managed by DNS controller.
break
}
if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
return err
}
routingResult.Outbound = uint8(outboundIndex)
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.Tracef("outbound: %v => %v",
consts.OutboundControlPlaneRouting.String(),
outboundIndex.String(),
)
}
// Reset dialTarget.
dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, realDst, domain)
default:
}
if routingResult.Mark == 0 { if routingResult.Mark == 0 {
routingResult.Mark = c.soMarkFromDae routingResult.Mark = c.soMarkFromDae
} }
@ -190,23 +164,6 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
}) })
} }
if int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
}
outbound := c.outbounds[outboundIndex]
// Select dialer from outbound (dialer group).
networkType := &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionFromAddr(realDst.Addr()),
IsDns: true, // UDP relies on DNS check result.
}
strictIpVersion := dialIp
dialerForNew, _, err := outbound.Select(networkType, strictIpVersion)
if err != nil {
return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
}
// Dial and send. // Dial and send.
// TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr). // TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr).
// Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP. // Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP.
@ -215,6 +172,17 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
// Get udp endpoint. // Get udp endpoint.
var ue *UdpEndpoint var ue *UdpEndpoint
retry := 0 retry := 0
networkType := &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionFromAddr(realDst.Addr()),
IsDns: true, // UDP relies on DNS check result.
}
// Get outbound.
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
if c.dialMode == consts.DialMode_DomainCao && domain != "" {
outboundIndex = consts.OutboundControlPlaneRouting
}
dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, realDst, domain)
getNew: getNew:
if retry > MaxRetry { if retry > MaxRetry {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
@ -232,16 +200,59 @@ getNew:
return sendPkt(data, from, realSrc, src, lConn, lanWanFlag) return sendPkt(data, from, realSrc, src, lConn, lanWanFlag)
}, },
NatTimeout: natTimeout, NatTimeout: natTimeout,
Dialer: dialerForNew, GetDialOption: func() (option *DialOption, err error) {
Network: common.MagicNetwork("udp", routingResult.Mark), if shouldReroute {
Target: dialTarget, outboundIndex = consts.OutboundControlPlaneRouting
}
switch outboundIndex {
case consts.OutboundDirect:
case consts.OutboundControlPlaneRouting:
if isDns {
// Routing of DNS packets are managed by DNS controller.
break
}
if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
return nil, err
}
routingResult.Outbound = uint8(outboundIndex)
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.Tracef("outbound: %v => %v",
consts.OutboundControlPlaneRouting.String(),
outboundIndex.String(),
)
}
// Reset dialTarget.
dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, realDst, domain)
default:
}
if int(outboundIndex) >= len(c.outbounds) {
return nil, fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
}
outbound := c.outbounds[outboundIndex]
// Select dialer from outbound (dialer group).
strictIpVersion := dialIp
dialerForNew, _, err := outbound.Select(networkType, strictIpVersion)
if err != nil {
return nil, fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
}
return &DialOption{
Target: dialTarget,
Dialer: dialerForNew,
Outbound: outbound,
Network: common.MagicNetwork("udp", routingResult.Mark),
}, nil
},
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to GetOrCreate (policy: %v): %w", outbound.GetSelectionPolicy(), err) return fmt.Errorf("failed to GetOrCreate: %w", err)
} }
// If the udp endpoint has been not alive, remove it from pool and get a new one. // If the udp endpoint has been not alive, remove it from pool and get a new one.
if !isNew && outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) { if !isNew && ue.Outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) {
if c.log.IsLevelEnabled(logrus.DebugLevel) { if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
@ -278,21 +289,19 @@ getNew:
// Print log. // Print log.
// Only print routing for new connection to avoid the log exploded (Quic and BT). // Only print routing for new connection to avoid the log exploded (Quic and BT).
if isNew { if isNew && c.log.IsLevelEnabled(logrus.InfoLevel) || c.log.IsLevelEnabled(logrus.DebugLevel) {
if c.log.IsLevelEnabled(logrus.InfoLevel) { fields := logrus.Fields{
fields := logrus.Fields{ "network": networkType.StringWithoutDns(),
"network": networkType.StringWithoutDns(), "outbound": ue.Outbound.Name,
"outbound": outbound.Name, "policy": ue.Outbound.GetSelectionPolicy(),
"policy": outbound.GetSelectionPolicy(), "dialer": ue.Dialer.Property().Name,
"dialer": ue.Dialer.Property().Name, "domain": domain,
"domain": domain, "ip": RefineAddrPortToShow(realDst),
"ip": RefineAddrPortToShow(realDst), "pid": routingResult.Pid,
"pid": routingResult.Pid, "pname": ProcessName2String(routingResult.Pname[:]),
"pname": ProcessName2String(routingResult.Pname[:]), "mac": Mac2String(routingResult.Mac[:]),
"mac": Mac2String(routingResult.Mac[:]),
}
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), dialTarget)
} }
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), dialTarget)
} }
return nil return nil

View File

@ -6,12 +6,14 @@
package control package control
import ( import (
"context"
"fmt" "fmt"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/outbound"
"github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pool" "github.com/mzz2017/softwind/pool"
@ -27,7 +29,8 @@ type UdpEndpoint struct {
handler UdpHandler handler UdpHandler
NatTimeout time.Duration NatTimeout time.Duration
Dialer *dialer.Dialer Dialer *dialer.Dialer
Outbound *outbound.DialerGroup
} }
func (ue *UdpEndpoint) start() { func (ue *UdpEndpoint) start() {
@ -65,46 +68,44 @@ func (ue *UdpEndpoint) Close() error {
// UdpEndpointPool is a full-cone udp conn pool // UdpEndpointPool is a full-cone udp conn pool
type UdpEndpointPool struct { type UdpEndpointPool struct {
pool map[netip.AddrPort]*UdpEndpoint pool sync.Map
mu sync.Mutex createMuMap sync.Map
} }
type UdpEndpointOptions struct { type UdpEndpointOptions struct {
Handler UdpHandler Handler UdpHandler
NatTimeout time.Duration NatTimeout time.Duration
Dialer *dialer.Dialer // GetTarget is useful only if the underlay does not support Full-cone.
// Network is useful for MagicNetwork GetDialOption func() (option *DialOption, err error)
Network string
// Target is useful only if the underlay does not support Full-cone.
Target string
} }
var DefaultUdpEndpointPool = NewUdpEndpointPool() var DefaultUdpEndpointPool = NewUdpEndpointPool()
func NewUdpEndpointPool() *UdpEndpointPool { func NewUdpEndpointPool() *UdpEndpointPool {
return &UdpEndpointPool{ return &UdpEndpointPool{}
pool: make(map[netip.AddrPort]*UdpEndpoint),
}
} }
func (p *UdpEndpointPool) Remove(lAddr netip.AddrPort, udpEndpoint *UdpEndpoint) (err error) { func (p *UdpEndpointPool) Remove(lAddr netip.AddrPort, udpEndpoint *UdpEndpoint) (err error) {
p.mu.Lock() if ue, ok := p.pool.LoadAndDelete(lAddr); ok {
defer p.mu.Unlock()
if ue, ok := p.pool[lAddr]; ok {
if ue != udpEndpoint { if ue != udpEndpoint {
return fmt.Errorf("target udp endpoint is not in the pool") return fmt.Errorf("target udp endpoint is not in the pool")
} }
ue.Close() ue.(*UdpEndpoint).Close()
delete(p.pool, lAddr)
} }
return nil return nil
} }
func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEndpointOptions) (udpEndpoint *UdpEndpoint, isNew bool, err error) { func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEndpointOptions) (udpEndpoint *UdpEndpoint, isNew bool, err error) {
// TODO: fine-grained lock. _ue, ok := p.pool.Load(lAddr)
p.mu.Lock() begin:
defer p.mu.Unlock()
ue, ok := p.pool[lAddr]
if !ok { if !ok {
createMu, _ := p.createMuMap.LoadOrStore(lAddr, &sync.Mutex{})
createMu.(*sync.Mutex).Lock()
defer createMu.(*sync.Mutex).Unlock()
defer p.createMuMap.Delete(lAddr)
_ue, ok = p.pool.Load(lAddr)
if ok {
goto begin
}
// Create an UdpEndpoint. // Create an UdpEndpoint.
if createOption == nil { if createOption == nil {
createOption = &UdpEndpointOptions{} createOption = &UdpEndpointOptions{}
@ -116,36 +117,45 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd
return nil, true, fmt.Errorf("createOption.Handler cannot be nil") return nil, true, fmt.Errorf("createOption.Handler cannot be nil")
} }
udpConn, err := createOption.Dialer.Dial(createOption.Network, createOption.Target) dialOption, err := createOption.GetDialOption()
if err != nil {
return nil, false, err
}
cd := netproxy.ContextDialer{
Dialer: dialOption.Dialer,
}
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel()
udpConn, err := cd.DialContext(ctx, dialOption.Network, dialOption.Target)
if err != nil { if err != nil {
return nil, true, err return nil, true, err
} }
if _, ok = udpConn.(netproxy.PacketConn); !ok { if _, ok = udpConn.(netproxy.PacketConn); !ok {
return nil, true, fmt.Errorf("protocol does not support udp") return nil, true, fmt.Errorf("protocol does not support udp")
} }
ue = &UdpEndpoint{ ue := &UdpEndpoint{
conn: udpConn.(netproxy.PacketConn), conn: udpConn.(netproxy.PacketConn),
deadlineTimer: time.AfterFunc(createOption.NatTimeout, func() { deadlineTimer: time.AfterFunc(createOption.NatTimeout, func() {
p.mu.Lock() if ue, ok := p.pool.LoadAndDelete(lAddr); ok {
defer p.mu.Unlock() ue.(*UdpEndpoint).Close()
if ue, ok := p.pool[lAddr]; ok {
ue.Close()
delete(p.pool, lAddr)
} }
}), }),
handler: createOption.Handler, handler: createOption.Handler,
NatTimeout: createOption.NatTimeout, NatTimeout: createOption.NatTimeout,
Dialer: createOption.Dialer, Dialer: dialOption.Dialer,
Outbound: dialOption.Outbound,
} }
p.pool[lAddr] = ue _ue = ue
p.pool.Store(lAddr, ue)
// Receive UDP messages. // Receive UDP messages.
go ue.start() go ue.start()
isNew = true isNew = true
} else { } else {
ue := _ue.(*UdpEndpoint)
// Postpone the deadline. // Postpone the deadline.
ue.mu.Lock() ue.mu.Lock()
ue.deadlineTimer.Reset(ue.NatTimeout) ue.deadlineTimer.Reset(ue.NatTimeout)
ue.mu.Unlock() ue.mu.Unlock()
} }
return ue, isNew, nil return _ue.(*UdpEndpoint), isNew, nil
} }

2
go.mod
View File

@ -12,7 +12,7 @@ require (
github.com/json-iterator/go v1.1.12 github.com/json-iterator/go v1.1.12
github.com/miekg/dns v1.1.55 github.com/miekg/dns v1.1.55
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/mzz2017/softwind v0.0.0-20230710142544-73a557cea4a4 github.com/mzz2017/softwind v0.0.0-20230710175107-0107af8a1d26
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
github.com/safchain/ethtool v0.3.0 github.com/safchain/ethtool v0.3.0
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3

4
go.sum
View File

@ -91,8 +91,8 @@ github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B
github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI= github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI=
github.com/mzz2017/quic-go v0.0.0-20230706143320-cc858d4932b7 h1:9zmZilN02x3byMB2X3x+B4iyKHkucv70WA4hsyZkjo8= github.com/mzz2017/quic-go v0.0.0-20230706143320-cc858d4932b7 h1:9zmZilN02x3byMB2X3x+B4iyKHkucv70WA4hsyZkjo8=
github.com/mzz2017/quic-go v0.0.0-20230706143320-cc858d4932b7/go.mod h1:3H6d55CEofIWWr3gQThiB27+hA3WG5tATtPovzEYPAA= github.com/mzz2017/quic-go v0.0.0-20230706143320-cc858d4932b7/go.mod h1:3H6d55CEofIWWr3gQThiB27+hA3WG5tATtPovzEYPAA=
github.com/mzz2017/softwind v0.0.0-20230710142544-73a557cea4a4 h1:U6oSJf+dwVXpBZGi73l77igid+sOy4jgJucjSrfowFU= github.com/mzz2017/softwind v0.0.0-20230710175107-0107af8a1d26 h1:kVjALMAhr+rYw77TfrpD8VNIRbZ2/2pN1AYWBcL6eqM=
github.com/mzz2017/softwind v0.0.0-20230710142544-73a557cea4a4/go.mod h1:Fz8fgR7/dbnfR6RLpeOMkUDyebq4xShdmjj+cE5jnJ4= github.com/mzz2017/softwind v0.0.0-20230710175107-0107af8a1d26/go.mod h1:Fz8fgR7/dbnfR6RLpeOMkUDyebq4xShdmjj+cE5jnJ4=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=