mirror of
https://github.com/daeuniverse/dae.git
synced 2025-07-15 18:29:08 +07:00
feat: support tcp:// and tcp+udp:// for dns_upstream (#11)
This commit is contained in:
209
control/udp.go
209
control/udp.go
@ -14,6 +14,7 @@ import (
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/outbound/dialer"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
@ -94,11 +95,12 @@ func sendPktBind(data []byte, from netip.AddrPort, to netip.AddrPort) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ControlPlane) RelayToUDP(to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAns bool) UdpHandler {
|
||||
func (c *ControlPlane) WriteToUDP(to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAnsFunc func(from netip.AddrPort) bool) UdpHandler {
|
||||
return func(data []byte, from netip.AddrPort) (err error) {
|
||||
// Do not return conn-unrelated err in this func.
|
||||
|
||||
if isDNS {
|
||||
validateRushAns := validateRushAnsFunc(from)
|
||||
data, err = c.DnsRespHandler(data, validateRushAns)
|
||||
if err != nil {
|
||||
if validateRushAns && errors.Is(err, SuspectedRushAnswerError) {
|
||||
@ -158,15 +160,6 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
|
||||
return nil
|
||||
}
|
||||
|
||||
// Need to make a DNS request.
|
||||
if c.dnsUpstream.IsValid() {
|
||||
c.log.Tracef("Modify dns target %v to upstream: %v", RefineAddrPortToShow(destToSend), c.dnsUpstream)
|
||||
// Modify dns target to upstream.
|
||||
// NOTICE: Routing was calculated in advance by the eBPF program.
|
||||
dummyFrom = &dst
|
||||
destToSend = c.dnsUpstream
|
||||
}
|
||||
|
||||
// Flip dns question to reduce dns pollution.
|
||||
FlipDnsQuestionCase(dnsMessage)
|
||||
// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
|
||||
@ -180,51 +173,167 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
|
||||
}
|
||||
}
|
||||
|
||||
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
|
||||
// Because additional record OPT may not be supported by home router.
|
||||
// So se should trust home devices even if they make rush-answer (or looks like).
|
||||
validateRushAns := outboundIndex == consts.OutboundDirect && !destToSend.Addr().IsPrivate()
|
||||
|
||||
// Get udp endpoint.
|
||||
l4proto := consts.L4ProtoStr_UDP
|
||||
ipversion := consts.IpVersionFromAddr(dst.Addr())
|
||||
getNew:
|
||||
ue, isNew, err := DefaultUdpEndpointPool.GetOrCreate(src, &UdpEndpointOptions{
|
||||
Handler: c.RelayToUDP(src, isDns, dummyFrom, validateRushAns),
|
||||
NatTimeout: natTimeout,
|
||||
DialerFunc: func() (*dialer.Dialer, error) {
|
||||
newDialer, err := outbound.Select(l4proto, ipversion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err)
|
||||
}
|
||||
return newDialer, nil
|
||||
},
|
||||
Target: destToSend,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to GetOrCreate: %w", err)
|
||||
}
|
||||
// If the udp endpoint has been not alive, remove it from pool and get a new one.
|
||||
if !isNew && !ue.Dialer.MustGetAlive(l4proto, ipversion) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"src": src.String(),
|
||||
"network": string(l4proto) + string(ipversion),
|
||||
"dialer": ue.Dialer.Name(),
|
||||
}).Debugln("Old udp endpoint is not alive and removed")
|
||||
_ = DefaultUdpEndpointPool.Remove(src, ue)
|
||||
goto getNew
|
||||
}
|
||||
// This is real dialer.
|
||||
d := ue.Dialer
|
||||
var dialerForNew *dialer.Dialer
|
||||
|
||||
if isNew {
|
||||
// For DNS request, modify dst to dns upstream.
|
||||
// NOTICE: We might modify l4proto and ipversion.
|
||||
if isDns && c.dnsUpstream != nil {
|
||||
// Modify dns target to upstream.
|
||||
// NOTICE: Routing was calculated in advance by the eBPF program.
|
||||
|
||||
/// Choose the best l4proto and ipversion.
|
||||
// Get available ipversions and l4protos for DNS upstream.
|
||||
ipversions, l4protos := c.dnsUpstream.SupportedNetworks()
|
||||
var (
|
||||
bestDialer *dialer.Dialer
|
||||
bestLatency time.Duration
|
||||
bestTarget netip.AddrPort
|
||||
)
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"ipversions": ipversions,
|
||||
"l4protos": l4protos,
|
||||
}).Debugln("Choose DNS path")
|
||||
// Get the min latency path.
|
||||
for _, ver := range ipversions {
|
||||
for _, proto := range l4protos {
|
||||
d, latency, err := outbound.Select(proto, ver)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"latency": latency,
|
||||
"ver": ver,
|
||||
"proto": proto,
|
||||
"outbound": outbound.Name,
|
||||
}).Debugln("Choose")
|
||||
if bestDialer == nil || latency < bestLatency {
|
||||
bestDialer = d
|
||||
bestLatency = latency
|
||||
l4proto = proto
|
||||
ipversion = ver
|
||||
}
|
||||
}
|
||||
}
|
||||
switch ipversion {
|
||||
case consts.IpVersionStr_4:
|
||||
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip4, c.dnsUpstream.Port)
|
||||
case consts.IpVersionStr_6:
|
||||
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip6, c.dnsUpstream.Port)
|
||||
}
|
||||
dialerForNew = bestDialer
|
||||
dummyFrom = &dst
|
||||
destToSend = bestTarget
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"Original": RefineAddrPortToShow(dst),
|
||||
"New": destToSend,
|
||||
"Network": string(l4proto) + string(ipversion),
|
||||
}).Traceln("Modify DNS target")
|
||||
}
|
||||
if dialerForNew == nil {
|
||||
dialerForNew, _, err = outbound.Select(l4proto, ipversion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
var isNew bool
|
||||
var realDialer *dialer.Dialer
|
||||
|
||||
udpHandler := c.WriteToUDP(src, isDns, dummyFrom, func(from netip.AddrPort) bool {
|
||||
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
|
||||
// Because additional record OPT may not be supported by home router.
|
||||
// So se should trust home devices even if they make rush-answer (or looks like).
|
||||
return outboundIndex == consts.OutboundDirect && !from.Addr().IsPrivate()
|
||||
})
|
||||
|
||||
// Dial and send.
|
||||
switch l4proto {
|
||||
case consts.L4ProtoStr_UDP:
|
||||
// Get udp endpoint.
|
||||
var ue *UdpEndpoint
|
||||
getNew:
|
||||
ue, isNew, err = DefaultUdpEndpointPool.GetOrCreate(src, &UdpEndpointOptions{
|
||||
Handler: udpHandler,
|
||||
NatTimeout: natTimeout,
|
||||
DialerFunc: func() (*dialer.Dialer, error) {
|
||||
return dialerForNew, nil
|
||||
},
|
||||
Target: destToSend,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to GetOrCreate: %w", err)
|
||||
}
|
||||
// If the udp endpoint has been not alive, remove it from pool and get a new one.
|
||||
if !isNew && !ue.Dialer.MustGetAlive(l4proto, ipversion) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"src": src.String(),
|
||||
"network": string(l4proto) + string(ipversion),
|
||||
"dialer": ue.Dialer.Name(),
|
||||
}).Debugln("Old udp endpoint is not alive and removed")
|
||||
_ = DefaultUdpEndpointPool.Remove(src, ue)
|
||||
goto getNew
|
||||
}
|
||||
// This is real dialer.
|
||||
realDialer = ue.Dialer
|
||||
|
||||
//log.Printf("WriteToUDPAddrPort->%v", destToSend)
|
||||
_, err = ue.WriteToUDPAddrPort(data, destToSend)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write UDP packet req: %w", err)
|
||||
}
|
||||
case consts.L4ProtoStr_TCP:
|
||||
// MUST be DNS.
|
||||
if !isDns {
|
||||
return fmt.Errorf("UDP to TCP only support DNS request")
|
||||
}
|
||||
realDialer = dialerForNew
|
||||
|
||||
// We can block because we are in a coroutine.
|
||||
|
||||
conn, err := dialerForNew.Dial("tcp", destToSend.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(natTimeout))
|
||||
// We should write two byte length in the front of TCP DNS request.
|
||||
bLen := pool.Get(2)
|
||||
defer pool.Put(bLen)
|
||||
binary.BigEndian.PutUint16(bLen, uint16(len(data)))
|
||||
_, err = conn.Write(bLen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write DNS req length: %w", err)
|
||||
}
|
||||
if _, err = conn.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write DNS req payload: %w", err)
|
||||
}
|
||||
|
||||
// Read two byte length.
|
||||
if _, err = io.ReadFull(conn, bLen); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
|
||||
}
|
||||
buf := pool.Get(int(binary.BigEndian.Uint16(bLen)))
|
||||
defer pool.Put(buf)
|
||||
if _, err = io.ReadFull(conn, buf); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload: %w", err)
|
||||
}
|
||||
if err = udpHandler(buf, destToSend); err != nil {
|
||||
return fmt.Errorf("failed to write DNS resp to client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Print log.
|
||||
if isNew || isDns {
|
||||
// Only print routing for new connection to avoid the log exploded (Quic and BT).
|
||||
if isDns && c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
|
||||
q := dnsMessage.Questions[0]
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"network": string(l4proto) + string(ipversion) + "(DNS)",
|
||||
"outbound": outbound.Name,
|
||||
"dialer": d.Name(),
|
||||
"dialer": realDialer.Name(),
|
||||
"qname": strings.ToLower(q.Name.String()),
|
||||
"qtype": q.Type,
|
||||
}).Infof("%v <-> %v",
|
||||
@ -235,16 +344,12 @@ getNew:
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"network": string(l4proto) + string(ipversion),
|
||||
"outbound": outbound.Name,
|
||||
"dialer": d.Name(),
|
||||
"dialer": realDialer.Name(),
|
||||
}).Infof("%v <-> %v",
|
||||
RefineSourceToShow(src, destToSend.Addr()), RefineAddrPortToShow(destToSend),
|
||||
)
|
||||
}
|
||||
}
|
||||
//log.Printf("WriteToUDPAddrPort->%v", destToSend)
|
||||
_, err = ue.WriteToUDPAddrPort(data, destToSend)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write UDP packet req: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user