diff --git a/README.md b/README.md
index c1cda87..155a0e9 100644
--- a/README.md
+++ b/README.md
@@ -49,22 +49,18 @@ Use following command to show kernel configuration items on your machine.
 zcat /proc/config.gz || cat /boot/{config,config-$(uname -r)}
 ```
 
-**Bind to LAN**
-
+dae needs the following kernel configuration items:
 ```
-CONFIG_DEBUG_INFO_BTF
+CONFIG_DEBUG_INFO_BTF=y
+CONFIG_NET_CLS_ACT=y
+CONFIG_NET_SCH_INGRESS=m
+CONFIG_NET_INGRESS=y
+CONFIG_NET_EGRESS=y
 ```
-
-**Bind to WAN**:
-
-```
-CONFIG_DEBUG_INFO_BTF
-```
-
 Check them using command like:
 
 ```shell
-(zcat /proc/config.gz || cat /boot/{config,config-$(uname -r)}) | grep 'CONFIG_DEBUG_INFO_BTF='
+(zcat /proc/config.gz || cat /boot/{config,config-$(uname -r)}) | grep -E 'CONFIG_(DEBUG_INFO_BTF|NET_CLS_ACT|NET_SCH_INGRESS|NET_INGRESS|NET_EGRESS)='
 ```
 
 ### Enable IP Forwarding
diff --git a/component/outbound/dialer/alive_dialer_set.go b/component/outbound/dialer/alive_dialer_set.go
index 2e9069a..6f48623 100644
--- a/component/outbound/dialer/alive_dialer_set.go
+++ b/component/outbound/dialer/alive_dialer_set.go
@@ -90,6 +90,9 @@ func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) {
 }
 
 func (a *AliveDialerSet) printLatencies() {
+    if !a.log.IsLevelEnabled(logrus.TraceLevel) {
+        return
+    }
     var builder strings.Builder
     builder.WriteString(fmt.Sprintf("%v (%v):\n", a.dialerGroupName, a.CheckTyp.String()))
     for _, d := range a.inorderedAliveDialerSet {
diff --git a/control/control_plane.go b/control/control_plane.go
index b5ecc01..248d08c 100644
--- a/control/control_plane.go
+++ b/control/control_plane.go
@@ -523,33 +523,32 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
                 }
                 break
             }
-            pktDst := RetrieveOriginalDest(oob[:oobn])
-            var newBuf []byte
-            var realDst netip.AddrPort
-            outboundIndex, err := c.core.RetrieveOutboundIndex(src, pktDst, unix.IPPROTO_UDP)
-            if err != nil {
-                // WAN. Old method.
-                addrHdr, dataOffset, err := ParseAddrHdr(buf[:n])
+            pktDst := RetrieveOriginalDest(oob[:oobn])
+            newBuf := pool.Get(n)
+            copy(newBuf, buf[:n])
+            go func(data []byte, src, pktDst netip.AddrPort) {
+                defer pool.Put(data)
+                var realDst netip.AddrPort
+                var outboundIndex consts.OutboundIndex
+                outboundIndex, err := c.core.RetrieveOutboundIndex(src, pktDst, unix.IPPROTO_UDP)
                 if err != nil {
-                    c.log.Warnf("No AddrPort presented")
-                    continue
+                    // WAN. Old method.
+                    addrHdr, dataOffset, err := ParseAddrHdr(data)
+                    if err != nil {
+                        c.log.Warnf("No AddrPort presented")
+                        return
+                    }
+                    data = data[:copy(data, data[dataOffset:])]
+                    outboundIndex = consts.OutboundIndex(addrHdr.Outbound)
+                    src = netip.AddrPortFrom(addrHdr.Dest.Addr(), src.Port())
+                    realDst = addrHdr.Dest
+                } else {
+                    realDst = pktDst
                 }
-                newBuf = pool.Get(n - dataOffset)
-                copy(newBuf, buf[dataOffset:n])
-                outboundIndex = consts.OutboundIndex(addrHdr.Outbound)
-                src = netip.AddrPortFrom(addrHdr.Dest.Addr(), src.Port())
-                realDst = addrHdr.Dest
-            } else {
-                newBuf = pool.Get(n)
-                copy(newBuf, buf[:n])
-                realDst = pktDst
-            }
-            go func(data []byte, src, pktDst, realDst netip.AddrPort, outboundIndex consts.OutboundIndex) {
-                if e := c.handlePkt(udpConn, newBuf, src, pktDst, realDst, outboundIndex); e != nil {
+                if e := c.handlePkt(udpConn, data, src, pktDst, realDst, outboundIndex); e != nil {
                     c.log.Warnln("handlePkt:", e)
                 }
-                pool.Put(newBuf)
-            }(newBuf, src, pktDst, realDst, outboundIndex)
+            }(newBuf, src, pktDst)
         }
     }()
     <-ctx.Done()
diff --git a/control/dns.go b/control/dns.go
index f1523f3..df6481a 100644
--- a/control/dns.go
+++ b/control/dns.go
@@ -85,6 +85,9 @@ func (c *ControlPlane) BatchUpdateDomainRouting(cache *dnsCache) error {
             ips = append(ips, netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA))
         }
     }
+    if len(ips) == 0 {
+        return nil
+    }
 
     // Update bpf map.
     // Construct keys and vals, and BpfMapBatchUpdate.
@@ -316,13 +319,15 @@ loop:
     }
 
     // Update dnsCache.
-    c.log.WithFields(logrus.Fields{
-        "qname": q.Name,
-        "rcode": msg.RCode,
-        "ans":   FormatDnsRsc(msg.Answers),
-        "auth":  FormatDnsRsc(msg.Authorities),
-        "addi":  FormatDnsRsc(msg.Additionals),
-    }).Tracef("Update DNS record cache")
+    if c.log.IsLevelEnabled(logrus.TraceLevel) {
+        c.log.WithFields(logrus.Fields{
+            "qname": q.Name,
+            "rcode": msg.RCode,
+            "ans":   FormatDnsRsc(msg.Answers),
+            "auth":  FormatDnsRsc(msg.Authorities),
+            "addi":  FormatDnsRsc(msg.Additionals),
+        }).Tracef("Update DNS record cache")
+    }
     if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil {
         return nil, err
     }
diff --git a/control/kern/tproxy.c b/control/kern/tproxy.c
index af42b88..3a28bef 100644
--- a/control/kern/tproxy.c
+++ b/control/kern/tproxy.c
@@ -49,10 +49,11 @@
 #define MAX_LPM_SIZE 20480
 #define MAX_LPM_NUM (MAX_MATCH_SET_LEN + 8)
 #define MAX_DST_MAPPING_NUM (65536 * 2)
-#define MAX_SRC_PID_PNAME_MAPPING_NUM (65536)
-#define IPV6_MAX_EXTENSIONS 4
+#define MAX_COOKIE_PID_PNAME_MAPPING_NUM (65536)
+#define MAX_DOMAIN_ROUTING_NUM 65536
 #define MAX_ARG_LEN_TO_PROBE 192
 #define MAX_ARG_SCANNER_BUFFER_SIZE (TASK_COMM_LEN * 4)
+#define IPV6_MAX_EXTENSIONS 4
 
 #define OUTBOUND_DIRECT 0
 #define OUTBOUND_BLOCK 1
@@ -327,7 +328,7 @@ struct {
     __uint(type, BPF_MAP_TYPE_LRU_HASH);
     __type(key, __be32[4]);
     __type(value, struct domain_routing);
-    __uint(max_entries, 65535);
+    __uint(max_entries, MAX_DOMAIN_ROUTING_NUM);
     /// NOTICE: No persistence.
     // __uint(pinning, LIBBPF_PIN_BY_NAME);
 } domain_routing_map SEC(".maps");
@@ -347,7 +348,7 @@ struct {
     __uint(type, BPF_MAP_TYPE_LRU_HASH);
     __type(key, __u64);
     __type(value, struct pid_pname);
-    __uint(max_entries, MAX_SRC_PID_PNAME_MAPPING_NUM);
+    __uint(max_entries, MAX_COOKIE_PID_PNAME_MAPPING_NUM);
     /// NOTICE: No persistence.
     __uint(pinning, LIBBPF_PIN_BY_NAME);
 } cookie_pid_map SEC(".maps");
diff --git a/control/tcp.go b/control/tcp.go
index 6fd4ddd..50d173e 100644
--- a/control/tcp.go
+++ b/control/tcp.go
@@ -66,10 +66,12 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
     case consts.OutboundMustDirect:
         fallthrough
     case consts.OutboundControlPlaneDirect:
-        c.log.Tracef("outbound: %v => %v",
-            outboundIndex.String(),
-            consts.OutboundDirect.String(),
-        )
+        if c.log.IsLevelEnabled(logrus.TraceLevel) {
+            c.log.Tracef("outbound: %v => %v",
+                outboundIndex.String(),
+                consts.OutboundDirect.String(),
+            )
+        }
         outboundIndex = consts.OutboundDirect
     default:
     }
@@ -87,13 +89,16 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
     if err != nil {
         return fmt.Errorf("failed to select dialer from group %v (%v): %w", outbound.Name, networkType.String(), err)
     }
-    c.log.WithFields(logrus.Fields{
-        "network":  networkType.String(),
-        "outbound": outbound.Name,
-        "policy":   outbound.GetSelectionPolicy(),
-        "dialer":   d.Name(),
-        "domain":   domain,
-    }).Infof("%v <-> %v", RefineSourceToShow(src, dst.Addr(), consts.LanWanFlag_NotApplicable), RefineAddrPortToShow(dst))
+
+    if c.log.IsLevelEnabled(logrus.InfoLevel) {
+        c.log.WithFields(logrus.Fields{
+            "network":  networkType.String(),
+            "outbound": outbound.Name,
+            "policy":   outbound.GetSelectionPolicy(),
+            "dialer":   d.Name(),
+            "domain":   domain,
+        }).Infof("%v <-> %v", RefineSourceToShow(src, dst.Addr(), consts.LanWanFlag_NotApplicable), RefineAddrPortToShow(dst))
+    }
 
     // Dial and relay.
     dst = netip.AddrPortFrom(common.ConvergeIp(dst.Addr()), dst.Port())
diff --git a/control/udp.go b/control/udp.go
index 3d53b03..1430558 100644
--- a/control/udp.go
+++ b/control/udp.go
@@ -35,7 +35,7 @@ var (
     UnspecifiedAddr6 = netip.AddrFrom16([16]byte{})
 )
 
-func ChooseNatTimeout(data []byte) (dmsg *dnsmessage.Message, timeout time.Duration) {
+func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Message, timeout time.Duration) {
     var dnsmsg dnsmessage.Message
-    if err := dnsmsg.Unpack(data); err == nil {
+    if err := dnsmsg.Unpack(data); err == nil && sniffDns {
         //log.Printf("DEBUG: lookup %v", dnsmsg.Questions[0].Name)
@@ -125,7 +125,9 @@ func (c *ControlPlane) WriteToUDP(lanWanFlag consts.LanWanFlag, lConn *net.UDPCo
         }).Tracef("DNS rush-answer rejected")
         return err
     }
-    c.log.Debugf("DnsRespHandler: %v", err)
+    if c.log.IsLevelEnabled(logrus.DebugLevel) {
+        c.log.Debugf("DnsRespHandler: %v", err)
+    }
     if data == nil {
         return nil
     }
@@ -159,10 +161,12 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
         mustDirect = true
         fallthrough
     case consts.OutboundControlPlaneDirect:
-        c.log.Tracef("outbound: %v => %v",
-            outboundIndex.String(),
-            consts.OutboundDirect.String(),
-        )
+        if c.log.IsLevelEnabled(logrus.TraceLevel) {
+            c.log.Tracef("outbound: %v => %v",
+                outboundIndex.String(),
+                consts.OutboundDirect.String(),
+            )
+        }
         outboundIndex = consts.OutboundDirect
     default:
     }
@@ -170,7 +174,8 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
         return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
     }
     outbound := c.outbounds[outboundIndex]
-    dnsMessage, natTimeout := ChooseNatTimeout(data)
+    // To stay consistent with the kernel program, only sniff DNS requests sent to port 53.
+    dnsMessage, natTimeout := ChooseNatTimeout(data, realDst.Port() == 53)
     // We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time.
     isDns := dnsMessage != nil
     var dummyFrom *netip.AddrPort
@@ -234,11 +239,13 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
             bestLatency time.Duration
             bestTarget  netip.AddrPort
         )
-        c.log.WithFields(logrus.Fields{
-            "ipversions": ipversions,
-            "l4protos":   l4protos,
-            "src":        realSrc.String(),
-        }).Traceln("Choose DNS path")
+        if c.log.IsLevelEnabled(logrus.TraceLevel) {
+            c.log.WithFields(logrus.Fields{
+                "ipversions": ipversions,
+                "l4protos":   l4protos,
+                "src":        realSrc.String(),
+            }).Traceln("Choose DNS path")
+        }
         // Get the min latency path.
         networkType := dialer.NetworkType{
             IsDns: isDns,
@@ -251,12 +258,14 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
             if err != nil {
                 continue
             }
-            c.log.WithFields(logrus.Fields{
-                "name":     d.Name(),
-                "latency":  latency,
-                "network":  networkType.String(),
-                "outbound": outbound.Name,
-            }).Traceln("Choice")
+            if c.log.IsLevelEnabled(logrus.TraceLevel) {
+                c.log.WithFields(logrus.Fields{
+                    "name":     d.Name(),
+                    "latency":  latency,
+                    "network":  networkType.String(),
+                    "outbound": outbound.Name,
+                }).Traceln("Choice")
+            }
             if bestDialer == nil || latency < bestLatency {
                 bestDialer = d
                 bestLatency = latency
@@ -274,11 +283,13 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
             dialerForNew = bestDialer
             dummyFrom = &realDst
             destToSend = bestTarget
-            c.log.WithFields(logrus.Fields{
-                "Original": RefineAddrPortToShow(realDst),
-                "New":      destToSend,
-                "Network":  string(l4proto) + string(ipversion),
-            }).Traceln("Modify DNS target")
+            if c.log.IsLevelEnabled(logrus.TraceLevel) {
+                c.log.WithFields(logrus.Fields{
+                    "Original": RefineAddrPortToShow(realDst),
+                    "New":      destToSend,
+                    "Network":  string(l4proto) + string(ipversion),
+                }).Traceln("Modify DNS target")
+            }
         }
         networkType := &dialer.NetworkType{
             L4Proto: l4proto,
@@ -329,12 +340,15 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
 
     // If the udp endpoint has been not alive, remove it from pool and get a new one.
     if !isNew && outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) {
-        c.log.WithFields(logrus.Fields{
-            "src":     RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag),
-            "network": networkType.String(),
-            "dialer":  ue.Dialer.Name(),
-            "retry":   retry,
-        }).Debugln("Old udp endpoint was not alive and removed.")
+
+        if c.log.IsLevelEnabled(logrus.DebugLevel) {
+            c.log.WithFields(logrus.Fields{
+                "src":     RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag),
+                "network": networkType.String(),
+                "dialer":  ue.Dialer.Name(),
+                "retry":   retry,
+            }).Debugln("Old udp endpoint was not alive and removed.")
+        }
         _ = DefaultUdpEndpointPool.Remove(realSrc, ue)
         retry++
         goto getNew
@@ -344,14 +358,16 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
 
     _, err = ue.WriteTo(data, tgtToSend)
     if err != nil {
-        c.log.WithFields(logrus.Fields{
-            "to":      destToSend.String(),
-            "domain":  domain,
-            "from":    realSrc.String(),
-            "network": networkType.String(),
-            "err":     err.Error(),
-            "retry":   retry,
-        }).Debugln("Failed to write UDP packet request. Try to remove old UDP endpoint and retry.")
+        if c.log.IsLevelEnabled(logrus.DebugLevel) {
+            c.log.WithFields(logrus.Fields{
+                "to":      destToSend.String(),
+                "domain":  domain,
+                "from":    realSrc.String(),
+                "network": networkType.String(),
+                "err":     err.Error(),
+                "retry":   retry,
+            }).Debugln("Failed to write UDP packet request. Try to remove old UDP endpoint and retry.")
+        }
         _ = DefaultUdpEndpointPool.Remove(realSrc, ue)
         retry++
         goto getNew
@@ -420,8 +436,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
         }).Infof("%v <-> %v",
             RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), RefineAddrPortToShow(destToSend),
         )
-    } else {
-        // TODO: Set-up ip to domain mapping and show domain if possible.
+    } else if c.log.IsLevelEnabled(logrus.InfoLevel) {
         c.log.WithFields(logrus.Fields{
             "network":  string(l4proto) + string(ipversion),
             "outbound": outbound.Name,
diff --git a/control/udp_endpoint.go b/control/udp_endpoint.go
index 127c1e2..2ff83cd 100644
--- a/control/udp_endpoint.go
+++ b/control/udp_endpoint.go
@@ -22,7 +22,7 @@ type UdpEndpoint struct {
     conn netproxy.PacketConn
     // mu protects deadlineTimer
     mu            sync.Mutex
-    deadlineTimer *time.Timer
+    deadlineTimer *time.Timer // nil means the UdpEndpoint has been closed
     handler       UdpHandler
     NatTimeout    time.Duration
 
@@ -48,7 +48,5 @@ func (ue *UdpEndpoint) start() {
         }
     }
-    ue.mu.Lock()
-    ue.deadlineTimer.Stop()
-    ue.mu.Unlock()
+    ue.Close()
 }
 
@@ -56,13 +56,15 @@
 func (ue *UdpEndpoint) WriteTo(b []byte, addr string) (int, error) {
     return ue.conn.WriteTo(b, addr)
 }
 
-func (ue *UdpEndpoint) Close() error {
+func (ue *UdpEndpoint) Close() (err error) {
     ue.mu.Lock()
     if ue.deadlineTimer != nil {
+        err = ue.conn.Close()
         ue.deadlineTimer.Stop()
+        ue.deadlineTimer = nil
     }
     ue.mu.Unlock()
-    return ue.conn.Close()
+    return err
 }
 
@@ -149,7 +151,9 @@ func (p *UdpEndpointPool) GetOrCreate(lAddr netip.AddrPort, createOption *UdpEnd
     } else {
         // Postpone the deadline.
         ue.mu.Lock()
-        ue.deadlineTimer.Reset(ue.NatTimeout)
+        if ue.deadlineTimer != nil {
+            ue.deadlineTimer.Reset(ue.NatTimeout)
+        }
         ue.mu.Unlock()
     }
     return ue, isNew, nil
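
The logging changes in alive_dialer_set.go, dns.go, tcp.go, and udp.go all apply the same micro-optimization: the arguments to WithFields/Tracef are evaluated at the call site even when the entry is later discarded, so the expensive field construction is now gated behind IsLevelEnabled. Below is a minimal standalone sketch of that pattern, not taken from dae itself; the expensiveSummary helper is a hypothetical stand-in for work such as FormatDnsRsc or building a large logrus.Fields map.

package main

import "github.com/sirupsen/logrus"

// expensiveSummary stands in for costly formatting that should only run
// when the log entry will actually be emitted.
func expensiveSummary() string {
	return "lots of formatted detail"
}

func main() {
	log := logrus.New()
	log.SetLevel(logrus.InfoLevel) // trace is disabled

	// Unguarded: expensiveSummary() still runs; its result is then thrown away.
	log.WithField("detail", expensiveSummary()).Trace("unguarded")

	// Guarded, as done throughout this diff: skip the work entirely
	// when the level is not enabled.
	if log.IsLevelEnabled(logrus.TraceLevel) {
		log.WithField("detail", expensiveSummary()).Trace("guarded")
	}
}

The same reasoning motivates the early return added to printLatencies, which now skips building its strings.Builder output when trace logging is off.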