mirror of
https://github.com/daeuniverse/dae.git
synced 2025-07-27 08:09:42 +07:00
fix/optimize/refactor(udp): fix potential stuck UDP and optimize reroute logic (#204)
This commit is contained in:
149
control/udp.go
149
control/udp.go
@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/daeuniverse/dae/common"
|
||||
"github.com/daeuniverse/dae/common/consts"
|
||||
ob "github.com/daeuniverse/dae/component/outbound"
|
||||
"github.com/daeuniverse/dae/component/outbound/dialer"
|
||||
"github.com/daeuniverse/dae/component/sniffing"
|
||||
internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
|
||||
@ -31,6 +32,13 @@ const (
|
||||
MaxRetry = 2
|
||||
)
|
||||
|
||||
type DialOption struct {
|
||||
Target string
|
||||
Dialer *dialer.Dialer
|
||||
Outbound *ob.DialerGroup
|
||||
Network string
|
||||
}
|
||||
|
||||
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) {
|
||||
if sniffDns {
|
||||
var dnsmsg dnsmessage.Msg
|
||||
@ -139,43 +147,9 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Get outbound.
|
||||
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
|
||||
if c.dialMode == consts.DialMode_DomainCao && domain != "" {
|
||||
outboundIndex = consts.OutboundControlPlaneRouting
|
||||
}
|
||||
|
||||
dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, realDst, domain)
|
||||
if shouldReroute {
|
||||
outboundIndex = consts.OutboundControlPlaneRouting
|
||||
}
|
||||
|
||||
if routingResult.Must > 0 {
|
||||
isDns = false // Regard as plain traffic.
|
||||
}
|
||||
switch outboundIndex {
|
||||
case consts.OutboundDirect:
|
||||
case consts.OutboundControlPlaneRouting:
|
||||
if isDns {
|
||||
// Routing of DNS packets are managed by DNS controller.
|
||||
break
|
||||
}
|
||||
|
||||
if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
|
||||
return err
|
||||
}
|
||||
routingResult.Outbound = uint8(outboundIndex)
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.Tracef("outbound: %v => %v",
|
||||
consts.OutboundControlPlaneRouting.String(),
|
||||
outboundIndex.String(),
|
||||
)
|
||||
}
|
||||
// Reset dialTarget.
|
||||
dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, realDst, domain)
|
||||
default:
|
||||
}
|
||||
if routingResult.Mark == 0 {
|
||||
routingResult.Mark = c.soMarkFromDae
|
||||
}
|
||||
@ -190,23 +164,6 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
})
|
||||
}
|
||||
|
||||
if int(outboundIndex) >= len(c.outbounds) {
|
||||
return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
|
||||
}
|
||||
outbound := c.outbounds[outboundIndex]
|
||||
|
||||
// Select dialer from outbound (dialer group).
|
||||
networkType := &dialer.NetworkType{
|
||||
L4Proto: consts.L4ProtoStr_UDP,
|
||||
IpVersion: consts.IpVersionFromAddr(realDst.Addr()),
|
||||
IsDns: true, // UDP relies on DNS check result.
|
||||
}
|
||||
strictIpVersion := dialIp
|
||||
dialerForNew, _, err := outbound.Select(networkType, strictIpVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
|
||||
}
|
||||
|
||||
// Dial and send.
|
||||
// TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr).
|
||||
// Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP.
|
||||
@ -215,6 +172,17 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
// Get udp endpoint.
|
||||
var ue *UdpEndpoint
|
||||
retry := 0
|
||||
networkType := &dialer.NetworkType{
|
||||
L4Proto: consts.L4ProtoStr_UDP,
|
||||
IpVersion: consts.IpVersionFromAddr(realDst.Addr()),
|
||||
IsDns: true, // UDP relies on DNS check result.
|
||||
}
|
||||
// Get outbound.
|
||||
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
|
||||
if c.dialMode == consts.DialMode_DomainCao && domain != "" {
|
||||
outboundIndex = consts.OutboundControlPlaneRouting
|
||||
}
|
||||
dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, realDst, domain)
|
||||
getNew:
|
||||
if retry > MaxRetry {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
@ -232,16 +200,59 @@ getNew:
|
||||
return sendPkt(data, from, realSrc, src, lConn, lanWanFlag)
|
||||
},
|
||||
NatTimeout: natTimeout,
|
||||
Dialer: dialerForNew,
|
||||
Network: common.MagicNetwork("udp", routingResult.Mark),
|
||||
Target: dialTarget,
|
||||
GetDialOption: func() (option *DialOption, err error) {
|
||||
if shouldReroute {
|
||||
outboundIndex = consts.OutboundControlPlaneRouting
|
||||
}
|
||||
|
||||
switch outboundIndex {
|
||||
case consts.OutboundDirect:
|
||||
case consts.OutboundControlPlaneRouting:
|
||||
if isDns {
|
||||
// Routing of DNS packets are managed by DNS controller.
|
||||
break
|
||||
}
|
||||
|
||||
if outboundIndex, routingResult.Mark, _, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
routingResult.Outbound = uint8(outboundIndex)
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.Tracef("outbound: %v => %v",
|
||||
consts.OutboundControlPlaneRouting.String(),
|
||||
outboundIndex.String(),
|
||||
)
|
||||
}
|
||||
// Reset dialTarget.
|
||||
dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, realDst, domain)
|
||||
default:
|
||||
}
|
||||
|
||||
if int(outboundIndex) >= len(c.outbounds) {
|
||||
return nil, fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
|
||||
}
|
||||
outbound := c.outbounds[outboundIndex]
|
||||
|
||||
// Select dialer from outbound (dialer group).
|
||||
strictIpVersion := dialIp
|
||||
dialerForNew, _, err := outbound.Select(networkType, strictIpVersion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
|
||||
}
|
||||
return &DialOption{
|
||||
Target: dialTarget,
|
||||
Dialer: dialerForNew,
|
||||
Outbound: outbound,
|
||||
Network: common.MagicNetwork("udp", routingResult.Mark),
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to GetOrCreate (policy: %v): %w", outbound.GetSelectionPolicy(), err)
|
||||
return fmt.Errorf("failed to GetOrCreate: %w", err)
|
||||
}
|
||||
|
||||
// If the udp endpoint has been not alive, remove it from pool and get a new one.
|
||||
if !isNew && outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) {
|
||||
if !isNew && ue.Outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) {
|
||||
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
@ -278,21 +289,19 @@ getNew:
|
||||
|
||||
// Print log.
|
||||
// Only print routing for new connection to avoid the log exploded (Quic and BT).
|
||||
if isNew {
|
||||
if c.log.IsLevelEnabled(logrus.InfoLevel) {
|
||||
fields := logrus.Fields{
|
||||
"network": networkType.StringWithoutDns(),
|
||||
"outbound": outbound.Name,
|
||||
"policy": outbound.GetSelectionPolicy(),
|
||||
"dialer": ue.Dialer.Property().Name,
|
||||
"domain": domain,
|
||||
"ip": RefineAddrPortToShow(realDst),
|
||||
"pid": routingResult.Pid,
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
}
|
||||
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), dialTarget)
|
||||
if isNew && c.log.IsLevelEnabled(logrus.InfoLevel) || c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
fields := logrus.Fields{
|
||||
"network": networkType.StringWithoutDns(),
|
||||
"outbound": ue.Outbound.Name,
|
||||
"policy": ue.Outbound.GetSelectionPolicy(),
|
||||
"dialer": ue.Dialer.Property().Name,
|
||||
"domain": domain,
|
||||
"ip": RefineAddrPortToShow(realDst),
|
||||
"pid": routingResult.Pid,
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
}
|
||||
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), dialTarget)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
Reference in New Issue
Block a user