mirror of
https://github.com/daeuniverse/dae.git
synced 2024-12-22 16:04:40 +07:00
fix: incidental packet drop and weird UDP state maintaining (#539)
This commit is contained in:
parent
ed50de2ef8
commit
93e47ffe88
@ -764,7 +764,15 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
|
||||
copy(newBuf, buf[:n])
|
||||
newOob := pool.Get(oobn)
|
||||
copy(newOob, oob[:oobn])
|
||||
go func(data pool.PB, oob pool.PB, src netip.AddrPort) {
|
||||
newSrc := src
|
||||
convergeSrc := common.ConvergeAddrPort(src)
|
||||
// Debug:
|
||||
// t := time.Now()
|
||||
DefaultUdpTaskPool.EmitTask(convergeSrc.String(), func() {
|
||||
data := newBuf
|
||||
oob := newOob
|
||||
src := newSrc
|
||||
|
||||
defer data.Put()
|
||||
defer oob.Put()
|
||||
var realDst netip.AddrPort
|
||||
@ -777,10 +785,13 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
|
||||
} else {
|
||||
realDst = pktDst
|
||||
}
|
||||
if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil {
|
||||
if e := c.handlePkt(udpConn, data, convergeSrc, common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil {
|
||||
c.log.Warnln("handlePkt:", e)
|
||||
}
|
||||
}(newBuf, newOob, src)
|
||||
})
|
||||
// if d := time.Since(t); d > 100*time.Millisecond {
|
||||
// logrus.Println(d)
|
||||
// }
|
||||
}
|
||||
}()
|
||||
c.ActivateCheck()
|
||||
|
@ -1290,24 +1290,19 @@ refresh_udp_conn_state_timer(struct tuples_key *key, bool is_egress)
|
||||
if (unlikely(!value))
|
||||
return NULL;
|
||||
|
||||
ret = bpf_timer_init(&value->timer, &udp_conn_state_map,
|
||||
CLOCK_MONOTONIC);
|
||||
if (unlikely(ret))
|
||||
goto del;
|
||||
if ((ret = bpf_timer_init(&value->timer, &udp_conn_state_map,
|
||||
CLOCK_MONOTONIC)))
|
||||
goto retn;
|
||||
|
||||
ret = bpf_timer_set_callback(&value->timer,
|
||||
refresh_udp_conn_state_timer_cb);
|
||||
if (unlikely(ret))
|
||||
goto del;
|
||||
if ((ret = bpf_timer_set_callback(&value->timer,
|
||||
refresh_udp_conn_state_timer_cb)))
|
||||
goto retn;
|
||||
|
||||
ret = bpf_timer_start(&value->timer, TIMEOUT_UDP_CONN_STATE, 0);
|
||||
if (unlikely(ret))
|
||||
goto del;
|
||||
if ((ret = bpf_timer_start(&value->timer, TIMEOUT_UDP_CONN_STATE, 0)))
|
||||
goto retn;
|
||||
|
||||
retn:
|
||||
return value;
|
||||
del:
|
||||
bpf_map_delete_elem(&udp_conn_state_map, key);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
SEC("tc/wan_ingress")
|
||||
@ -1515,17 +1510,22 @@ int tproxy_wan_egress(struct __sk_buff *skb)
|
||||
flag[6] = tuples.dscp;
|
||||
struct pid_pname *pid_pname;
|
||||
|
||||
struct udp_conn_state *conn_state =
|
||||
refresh_udp_conn_state_timer(&tuples.five, true);
|
||||
if (!conn_state)
|
||||
return TC_ACT_SHOT;
|
||||
if (!conn_state->is_egress ||
|
||||
pid_is_control_plane(skb, &pid_pname)) {
|
||||
// Input udp connection or
|
||||
if (pid_is_control_plane(skb, &pid_pname)) {
|
||||
// from control plane
|
||||
// => direct.
|
||||
return TC_ACT_OK;
|
||||
}
|
||||
|
||||
struct udp_conn_state *conn_state =
|
||||
refresh_udp_conn_state_timer(&tuples.five, true);
|
||||
if (!conn_state)
|
||||
return TC_ACT_SHOT;
|
||||
if (!conn_state->is_egress) {
|
||||
// Input udp connection
|
||||
// => direct.
|
||||
return TC_ACT_OK;
|
||||
}
|
||||
|
||||
if (pid_pname) {
|
||||
// 2, 3, 4, 5
|
||||
__builtin_memcpy(&flag[2], pid_pname->pname,
|
||||
|
@ -37,7 +37,7 @@ type PacketSnifferKey struct {
|
||||
RAddr netip.AddrPort
|
||||
}
|
||||
|
||||
var DefaultPacketSnifferPool = NewPacketSnifferPool()
|
||||
var DefaultPacketSnifferSessionMgr = NewPacketSnifferPool()
|
||||
|
||||
func NewPacketSnifferPool() *PacketSnifferPool {
|
||||
return &PacketSnifferPool{}
|
||||
|
@ -21,7 +21,7 @@ var testPacketSnifferData = []string{
|
||||
func TestPacketSniffer_Normal(t *testing.T) {
|
||||
for _, _data := range testPacketSnifferData {
|
||||
data, _ := hex.DecodeString(_data)
|
||||
sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(PacketSnifferKey{
|
||||
sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(PacketSnifferKey{
|
||||
LAddr: netip.MustParseAddrPort("1.1.1.1:1111"),
|
||||
RAddr: netip.MustParseAddrPort("2.2.2.2:2222"),
|
||||
}, nil)
|
||||
@ -44,7 +44,7 @@ func TestPacketSniffer_Mismatched(t *testing.T) {
|
||||
dst := netip.MustParseAddrPort("2.2.2.2:2222")
|
||||
for _, _data := range testPacketSnifferData {
|
||||
data, _ := hex.DecodeString(_data)
|
||||
sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(PacketSnifferKey{
|
||||
sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(PacketSnifferKey{
|
||||
LAddr: netip.MustParseAddrPort("1.1.1.1:1111"),
|
||||
RAddr: dst,
|
||||
}, nil)
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"time"
|
||||
|
||||
"github.com/daeuniverse/dae/common"
|
||||
@ -29,10 +30,11 @@ const (
|
||||
)
|
||||
|
||||
type DialOption struct {
|
||||
Target string
|
||||
Dialer *dialer.Dialer
|
||||
Outbound *ob.DialerGroup
|
||||
Network string
|
||||
Target string
|
||||
Dialer *dialer.Dialer
|
||||
Outbound *ob.DialerGroup
|
||||
Network string
|
||||
SniffedDomain string
|
||||
}
|
||||
|
||||
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) {
|
||||
@ -60,21 +62,50 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
var realSrc netip.AddrPort
|
||||
var domain string
|
||||
realSrc = src
|
||||
ue, ueExists := DefaultUdpEndpointPool.Get(realSrc)
|
||||
if ueExists && ue.SniffedDomain != "" {
|
||||
// It is quic ...
|
||||
// Fast path.
|
||||
domain := ue.SniffedDomain
|
||||
dialTarget := realDst.String()
|
||||
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
fields := logrus.Fields{
|
||||
"network": "udp(fp)",
|
||||
"outbound": ue.Outbound.Name,
|
||||
"policy": ue.Outbound.GetSelectionPolicy(),
|
||||
"dialer": ue.Dialer.Property().Name,
|
||||
"sniffed": domain,
|
||||
"ip": RefineAddrPortToShow(realDst),
|
||||
"pid": routingResult.Pid,
|
||||
"dscp": routingResult.Dscp,
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
}
|
||||
c.log.WithFields(fields).Tracef("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget)
|
||||
}
|
||||
|
||||
_, err = ue.WriteTo(data, dialTarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// To keep consistency with kernel program, we only sniff DNS request sent to 53.
|
||||
dnsMessage, natTimeout := ChooseNatTimeout(data, realDst.Port() == 53)
|
||||
// We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time.
|
||||
isDns := dnsMessage != nil
|
||||
if !isDns && !skipSniffing && !DefaultUdpEndpointPool.Exists(realSrc) {
|
||||
if !isDns && !skipSniffing && !ueExists {
|
||||
// Sniff Quic, ...
|
||||
key := PacketSnifferKey{
|
||||
LAddr: realSrc,
|
||||
RAddr: realDst,
|
||||
}
|
||||
_sniffer, _ := DefaultPacketSnifferPool.GetOrCreate(key, nil)
|
||||
_sniffer, _ := DefaultPacketSnifferSessionMgr.GetOrCreate(key, nil)
|
||||
_sniffer.Mu.Lock()
|
||||
// Re-get sniffer from pool to confirm the transaction is not done.
|
||||
sniffer := DefaultPacketSnifferPool.Get(key)
|
||||
sniffer := DefaultPacketSnifferSessionMgr.Get(key)
|
||||
if _sniffer == sniffer {
|
||||
sniffer.AppendData(data)
|
||||
domain, err = sniffer.SniffUdp()
|
||||
@ -92,7 +123,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
WithField("to", realDst).
|
||||
Trace("sniffUdp")
|
||||
}
|
||||
defer DefaultPacketSnifferPool.Remove(key, sniffer)
|
||||
defer DefaultPacketSnifferSessionMgr.Remove(key, sniffer)
|
||||
// Re-handlePkt after self func.
|
||||
toRehandle := sniffer.Data()[1 : len(sniffer.Data())-1] // Skip the first empty and the last (self).
|
||||
sniffer.Mu.Unlock()
|
||||
@ -134,7 +165,6 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
// However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine.
|
||||
|
||||
// Get udp endpoint.
|
||||
var ue *UdpEndpoint
|
||||
retry := 0
|
||||
networkType := &dialer.NetworkType{
|
||||
L4Proto: consts.L4ProtoStr_UDP,
|
||||
@ -217,10 +247,11 @@ getNew:
|
||||
return nil, fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
|
||||
}
|
||||
return &DialOption{
|
||||
Target: dialTarget,
|
||||
Dialer: dialerForNew,
|
||||
Outbound: outbound,
|
||||
Network: common.MagicNetwork("udp", routingResult.Mark),
|
||||
Target: dialTarget,
|
||||
Dialer: dialerForNew,
|
||||
Outbound: outbound,
|
||||
Network: common.MagicNetwork("udp", routingResult.Mark),
|
||||
SniffedDomain: domain,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
@ -243,6 +274,10 @@ getNew:
|
||||
retry++
|
||||
goto getNew
|
||||
}
|
||||
if domain == "" {
|
||||
// It is used for showing.
|
||||
domain = ue.SniffedDomain
|
||||
}
|
||||
|
||||
_, err = ue.WriteTo(data, dialTarget)
|
||||
if err != nil {
|
||||
@ -280,7 +315,11 @@ getNew:
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
}
|
||||
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget)
|
||||
logger := c.log.WithFields(fields).Infof
|
||||
if !isNew && c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
logger = c.log.WithFields(fields).Debugf
|
||||
}
|
||||
logger("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr()), dialTarget)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -31,6 +31,10 @@ type UdpEndpoint struct {
|
||||
|
||||
Dialer *dialer.Dialer
|
||||
Outbound *outbound.DialerGroup
|
||||
|
||||
// Non-empty indicates this UDP Endpoint is related with a sniffed domain.
|
||||
SniffedDomain string
|
||||
DialTarget string
|
||||
}
|
||||
|
||||
func (ue *UdpEndpoint) start() {
|
||||
@ -95,9 +99,12 @@ func (p *UdpEndpointPool) Remove(lAddr netip.AddrPort, udpEndpoint *UdpEndpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *UdpEndpointPool) Exists(lAddr netip.AddrPort) (ok bool) {
|
||||
_, ok = p.pool.Load(lAddr)
|
||||
return ok
|
||||
func (p *UdpEndpointPool) Get(lAddr netip.AddrPort) (udpEndpoint *UdpEndpoint, ok bool) {
|
||||
_ue, ok := p.pool.Load(lAddr)
|
||||
if !ok {
|
||||
return nil, ok
|
||||
}
|
||||
return _ue.(*UdpEndpoint), ok
|
||||
}
|
||||
|
||||
func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEndpointOptions) (udpEndpoint *UdpEndpoint, isNew bool, err error) {
|
||||
@ -146,6 +153,8 @@ begin:
|
||||
NatTimeout: createOption.NatTimeout,
|
||||
Dialer: dialOption.Dialer,
|
||||
Outbound: dialOption.Outbound,
|
||||
SniffedDomain: dialOption.SniffedDomain,
|
||||
DialTarget: dialOption.Target,
|
||||
}
|
||||
ue.deadlineTimer = time.AfterFunc(createOption.NatTimeout, func() {
|
||||
if _ue, ok := p.pool.LoadAndDelete(lAddr); ok {
|
||||
|
92
control/udp_task_pool.go
Normal file
92
control/udp_task_pool.go
Normal file
@ -0,0 +1,92 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const UdpTaskQueueLength = 128
|
||||
|
||||
type UdpTask = func()
|
||||
|
||||
type UdpTaskQueue struct {
|
||||
ch chan UdpTask
|
||||
timer *time.Timer
|
||||
agingTime time.Duration
|
||||
closed chan struct{}
|
||||
freed chan struct{}
|
||||
}
|
||||
|
||||
func (q *UdpTaskQueue) Push(task UdpTask) {
|
||||
q.timer.Reset(q.agingTime)
|
||||
q.ch <- task
|
||||
}
|
||||
|
||||
type UdpTaskPool struct {
|
||||
queueChPool sync.Pool
|
||||
// mu protects m
|
||||
mu sync.Mutex
|
||||
m map[string]*UdpTaskQueue
|
||||
}
|
||||
|
||||
func NewUdpTaskPool() *UdpTaskPool {
|
||||
p := &UdpTaskPool{
|
||||
queueChPool: sync.Pool{New: func() any {
|
||||
return make(chan UdpTask, UdpTaskQueueLength)
|
||||
}},
|
||||
mu: sync.Mutex{},
|
||||
m: map[string]*UdpTaskQueue{},
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *UdpTaskPool) convoy(q *UdpTaskQueue) {
|
||||
for {
|
||||
select {
|
||||
case <-q.closed:
|
||||
clearloop:
|
||||
for {
|
||||
select {
|
||||
case <-q.ch:
|
||||
default:
|
||||
break clearloop
|
||||
}
|
||||
}
|
||||
close(q.freed)
|
||||
return
|
||||
case t := <-q.ch:
|
||||
t()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *UdpTaskPool) EmitTask(key string, task UdpTask) {
|
||||
p.mu.Lock()
|
||||
q, ok := p.m[key]
|
||||
if !ok {
|
||||
ch := p.queueChPool.Get().(chan UdpTask)
|
||||
q = &UdpTaskQueue{
|
||||
ch: ch,
|
||||
timer: nil,
|
||||
agingTime: DefaultNatTimeout,
|
||||
closed: make(chan struct{}),
|
||||
freed: make(chan struct{}),
|
||||
}
|
||||
q.timer = time.AfterFunc(q.agingTime, func() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.m[key] == q {
|
||||
delete(p.m, key)
|
||||
}
|
||||
close(q.closed)
|
||||
<-q.freed
|
||||
p.queueChPool.Put(ch)
|
||||
})
|
||||
p.m[key] = q
|
||||
go p.convoy(q)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
q.Push(task)
|
||||
}
|
||||
|
||||
var DefaultUdpTaskPool = NewUdpTaskPool()
|
1
go.mod
1
go.mod
@ -39,6 +39,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/klauspost/compress v1.17.4 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.11.0 // indirect
|
||||
github.com/stretchr/testify v1.8.1 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
golang.org/x/mod v0.12.0 // indirect
|
||||
golang.org/x/net v0.20.0 // indirect
|
||||
|
7
go.sum
7
go.sum
@ -133,12 +133,17 @@ github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRM
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg=
|
||||
github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ=
|
||||
github.com/v2rayA/ahocorasick-domain v0.0.0-20231231085011-99ceb8ef3208 h1:s/K1ome/+rTDictkqGhqLuAleUymyWnvgNWARjblS9U=
|
||||
|
Loading…
Reference in New Issue
Block a user