/* * SPDX-License-Identifier: AGPL-3.0-only * Copyright (c) 2022-2024, daeuniverse Organization */ package control import ( "context" "errors" "fmt" "net" "net/netip" "os" "path/filepath" "strconv" "strings" "sync" "syscall" "time" "github.com/bits-and-blooms/bloom/v3" "github.com/cilium/ebpf" "github.com/cilium/ebpf/asm" "github.com/cilium/ebpf/features" "github.com/cilium/ebpf/rlimit" "github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common/assets" "github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/netutils" "github.com/daeuniverse/dae/component/dns" "github.com/daeuniverse/dae/component/outbound" "github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/pkg/config_parser" internal "github.com/daeuniverse/dae/pkg/ebpf_internal" "github.com/daeuniverse/outbound/pool" "github.com/daeuniverse/outbound/protocol/direct" "github.com/daeuniverse/outbound/transport/grpc" "github.com/daeuniverse/outbound/transport/meek" dnsmessage "github.com/miekg/dns" "github.com/mohae/deepcopy" "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) type ControlPlane struct { log *logrus.Logger core *controlPlaneCore deferFuncs []func() error listenIp string // TODO: add mutex? outbounds []*outbound.DialerGroup inConnections sync.Map dnsController *DnsController onceNetworkReady sync.Once dialMode consts.DialMode routingMatcher *RoutingMatcher ctx context.Context cancel context.CancelFunc ready chan struct{} muRealDomainSet sync.Mutex realDomainSet *bloom.BloomFilter wanInterface []string lanInterface []string sniffingTimeout time.Duration tproxyPortProtect bool soMarkFromDae uint32 mptcp bool } func NewControlPlane( log *logrus.Logger, _bpf interface{}, dnsCache map[string]*DnsCache, tagToNodeList map[string][]string, groups []config.Group, routingA *config.Routing, global *config.Global, dnsConfig *config.Dns, externGeoDataDirs []string, ) (*ControlPlane, error) { // TODO: Some users reported that enabling GSO on the client would affect the performance of watching YouTube, so we disabled it by default. if _, ok := os.LookupEnv("QUIC_GO_DISABLE_GSO"); !ok { os.Setenv("QUIC_GO_DISABLE_GSO", "1") } var err error kernelVersion, e := internal.KernelVersion() if e != nil { return nil, fmt.Errorf("failed to get kernel version: %w", e) } /// Check linux kernel requirements. // Check version from high to low to reduce the number of user upgrading kernel. if err := features.HaveProgramHelper(ebpf.SchedCLS, asm.FnLoop); err != nil { return nil, fmt.Errorf("%w: your kernel version %v does not support bpf_loop (needed by routing); expect >=%v; upgrade your kernel and try again", err, kernelVersion.String(), consts.BpfLoopFeatureVersion.String()) } if requirement := consts.ChecksumFeatureVersion; kernelVersion.Less(requirement) { return nil, fmt.Errorf("your kernel version %v does not support checksum related features; expect >=%v; upgrade your kernel and try again", kernelVersion.String(), requirement.String()) } if requirement := consts.BpfTimerFeatureVersion; len(global.WanInterface) > 0 && kernelVersion.Less(requirement) { return nil, fmt.Errorf("your kernel version %v does not support bind to WAN; expect >=%v; remove wan_interface in config file and try again", kernelVersion.String(), requirement.String()) } if requirement := consts.SkAssignFeatureVersion; len(global.LanInterface) > 0 && kernelVersion.Less(requirement) { return nil, fmt.Errorf("your kernel version %v does not support bind to LAN; expect >=%v; remove lan_interface in config file and try again", kernelVersion.String(), requirement.String()) } if kernelVersion.Less(consts.BasicFeatureVersion) { return nil, fmt.Errorf("your kernel version %v does not satisfy basic requirement; expect >=%v", kernelVersion.String(), consts.BasicFeatureVersion.String()) } var deferFuncs []func() error /// Allow the current process to lock memory for eBPF resources. if err = rlimit.RemoveMemlock(); err != nil { return nil, fmt.Errorf("rlimit.RemoveMemlock:%v", err) } InitDaeNetns(log) if err = InitSysctlManager(log); err != nil { return nil, err } if err = GetDaeNetns().Setup(); err != nil { return nil, fmt.Errorf("failed to setup dae netns: %w", err) } pinPath := filepath.Join(consts.BpfPinRoot, consts.AppName) if err = os.MkdirAll(pinPath, 0755); err != nil && !os.IsExist(err) { if os.IsNotExist(err) { log.Warnln("Perhaps you are in a container environment (such as lxc). If so, please use higher virtualization (kvm/qemu).") } return nil, err } /// Load pre-compiled programs and maps into the kernel. if _bpf == nil { log.Infof("Loading eBPF programs and maps into the kernel...") log.Infof("The loading process takes about 120MB free memory, which will be released after loading. Insufficient memory will cause loading failure.") } //var bpf bpfObjects var ProgramOptions = ebpf.ProgramOptions{ KernelTypes: nil, LogSize: ebpf.DefaultVerifierLogSize * 10, } if log.Level == logrus.PanicLevel { ProgramOptions.LogLevel = ebpf.LogLevelBranch | ebpf.LogLevelStats // ProgramOptions.LogLevel = ebpf.LogLevelInstruction | ebpf.LogLevelStats } collectionOpts := &ebpf.CollectionOptions{ Maps: ebpf.MapOptions{ PinPath: pinPath, }, Programs: ProgramOptions, } var bpf *bpfObjects if _bpf != nil { if _bpf, ok := _bpf.(*bpfObjects); ok { bpf = _bpf } else { return nil, fmt.Errorf("unexpected bpf type: %T", _bpf) } } else { bpf = new(bpfObjects) if err = fullLoadBpfObjects(log, bpf, &loadBpfOptions{ PinPath: pinPath, BigEndianTproxyPort: uint32(common.Htons(global.TproxyPort)), CollectionOptions: collectionOpts, }); err != nil { if log.Level == logrus.PanicLevel { log.Panicln(err) } return nil, fmt.Errorf("load eBPF objects: %w", err) } } log.Infof("Loaded eBPF programs and maps") // outboundId2Name can be modified later. outboundId2Name := make(map[uint8]string) core := newControlPlaneCore( log, bpf, outboundId2Name, &kernelVersion, _bpf != nil, ) defer func() { if err != nil { // Flip back. core.Flip() _ = core.Close() } }() /// Bind to links. Binding should be advance of dialerGroups to avoid un-routable old connection. // Bind to LAN if len(global.LanInterface) > 0 { if global.AutoConfigKernelParameter { _ = SetIpv4forward("1") _ = setForwarding("all", consts.IpVersionStr_6, "1") } global.LanInterface = common.Deduplicate(global.LanInterface) for _, ifname := range global.LanInterface { if err = core.bindLan(ifname, global.AutoConfigKernelParameter); err != nil { return nil, fmt.Errorf("bindLan: %v: %w", ifname, err) } } } // Bind to WAN if len(global.WanInterface) > 0 { if err = core.setupSkPidMonitor(); err != nil { log.WithError(err).Warnln("cgroup2 is not enabled; pname routing cannot be used") } if global.EnableLocalTcpFastRedirect { if err = core.setupLocalTcpFastRedirect(); err != nil { log.WithError(err).Warnln("failed to setup local tcp fast redirect") } } for _, ifname := range global.WanInterface { if len(global.LanInterface) > 0 { // FIXME: Code is not elegant here. // bindLan setting conf.ipv6.all.forwarding=1 suppresses accept_ra=1, // thus we set it 2 as a workaround. // See https://sysctl-explorer.net/net/ipv6/accept_ra/ for more information. if global.AutoConfigKernelParameter { acceptRa := sysctl.Keyf("net.ipv6.conf.%v.accept_ra", ifname) val, _ := acceptRa.Get() if val == "1" { _ = acceptRa.Set("2", false) } } } if err = core.bindWan(ifname, global.AutoConfigKernelParameter); err != nil { return nil, fmt.Errorf("bindWan: %v: %w", ifname, err) } } } // Bind to dae0 and dae0peer if err = core.bindDaens(); err != nil { return nil, fmt.Errorf("bindDaens: %w", err) } /// DialerGroups (outbounds). if global.AllowInsecure { log.Warnln("AllowInsecure is enabled, but it is not recommended. Please make sure you have to turn it on.") } option := dialer.NewGlobalOption(global, log) // Dial mode. dialMode, err := consts.ParseDialMode(global.DialMode) if err != nil { return nil, err } sniffingTimeout := global.SniffingTimeout if dialMode == consts.DialMode_Ip { sniffingTimeout = 0 } disableKernelAliveCallback := dialMode != consts.DialMode_Ip _direct, directProperty := dialer.NewDirectDialer(option, true) direct := dialer.NewDialer(_direct, option, dialer.InstanceOption{DisableCheck: true}, directProperty) _block, blockProperty := dialer.NewBlockDialer(option, func() { /*Dialer Outbound*/ }) block := dialer.NewDialer(_block, option, dialer.InstanceOption{DisableCheck: true}, blockProperty) outbounds := []*outbound.DialerGroup{ outbound.NewDialerGroup(option, consts.OutboundDirect.String(), []*dialer.Dialer{direct}, []*dialer.Annotation{{}}, outbound.DialerSelectionPolicy{ Policy: consts.DialerSelectionPolicy_Fixed, FixedIndex: 0, }, core.outboundAliveChangeCallback(0, disableKernelAliveCallback)), outbound.NewDialerGroup(option, consts.OutboundBlock.String(), []*dialer.Dialer{block}, []*dialer.Annotation{{}}, outbound.DialerSelectionPolicy{ Policy: consts.DialerSelectionPolicy_Fixed, FixedIndex: 0, }, core.outboundAliveChangeCallback(1, disableKernelAliveCallback)), } // Filter out groups. // FIXME: Ugly code here: reset grpc and meek clients manually. grpc.CleanGlobalClientConnectionCache() meek.CleanGlobalRoundTripperCache() dialerSet := outbound.NewDialerSetFromLinks(option, tagToNodeList) deferFuncs = append(deferFuncs, dialerSet.Close) for _, group := range groups { // Parse policy. policy, err := outbound.NewDialerSelectionPolicyFromGroupParam(&group) if err != nil { return nil, fmt.Errorf("failed to create group %v: %w", group.Name, err) } // Filter nodes with user given filters. dialers, annos, err := dialerSet.FilterAndAnnotate(group.Filter, group.FilterAnnotation) if err != nil { return nil, fmt.Errorf(`failed to create group "%v": %w`, group.Name, err) } // Convert node links to dialers. log.Infof(`Group "%v" node list:`, group.Name) for _, d := range dialers { log.Infoln("\t" + d.Property().Name) } if len(dialers) == 0 { log.Infoln("\t") } groupOption, err := ParseGroupOverrideOption(group, *global, log) finalOption := option if err == nil && groupOption != nil { newDialers := make([]*dialer.Dialer, 0) for _, d := range dialers { newDialer := d.Clone() deferFuncs = append(deferFuncs, newDialer.Close) newDialer.GlobalOption = groupOption newDialers = append(newDialers, newDialer) } log.Infof(`Group "%v"'s check option has been override.`, group.Name) dialers = newDialers finalOption = groupOption } // Create dialer group and append it to outbounds. dialerGroup := outbound.NewDialerGroup(finalOption, group.Name, dialers, annos, *policy, core.outboundAliveChangeCallback(uint8(len(outbounds)), disableKernelAliveCallback)) outbounds = append(outbounds, dialerGroup) } /// Routing. // Generate outboundName2Id from outbounds. if len(outbounds) > int(consts.OutboundUserDefinedMax) { return nil, fmt.Errorf("too many outbounds") } outboundName2Id := make(map[string]uint8) for i, o := range outbounds { if _, exist := outboundName2Id[o.Name]; exist { return nil, fmt.Errorf("duplicated outbound name: %v", o.Name) } outboundName2Id[o.Name] = uint8(i) outboundId2Name[uint8(i)] = o.Name } // Apply rules optimizers. locationFinder := assets.NewLocationFinder(externGeoDataDirs) var rules []*config_parser.RoutingRule if rules, err = routing.ApplyRulesOptimizers(routingA.Rules, &routing.AliasOptimizer{}, &routing.DatReaderOptimizer{Logger: log, LocationFinder: locationFinder}, &routing.MergeAndSortRulesOptimizer{}, &routing.DeduplicateParamsOptimizer{}, ); err != nil { return nil, fmt.Errorf("ApplyRulesOptimizers error:\n%w", err) } routingA.Rules = nil // Release. if log.IsLevelEnabled(logrus.DebugLevel) { var debugBuilder strings.Builder for _, rule := range rules { debugBuilder.WriteString(rule.String(true, false, false) + "\n") } log.Debugf("RoutingA:\n%vfallback: %v\n", debugBuilder.String(), routingA.Fallback) } // Parse rules and build. builder, err := NewRoutingMatcherBuilder(log, rules, outboundName2Id, core.bpf, routingA.Fallback) if err != nil { return nil, fmt.Errorf("NewRoutingMatcherBuilder: %w", err) } if err = builder.BuildKernspace(log); err != nil { return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err) } routingMatcher, err := builder.BuildUserspace() if err != nil { return nil, fmt.Errorf("RoutingMatcherBuilder.BuildUserspace: %w", err) } // New control plane. ctx, cancel := context.WithCancel(context.Background()) plane := &ControlPlane{ log: log, core: core, deferFuncs: deferFuncs, listenIp: "0.0.0.0", outbounds: outbounds, dnsController: nil, onceNetworkReady: sync.Once{}, dialMode: dialMode, routingMatcher: routingMatcher, ctx: ctx, cancel: cancel, ready: make(chan struct{}), muRealDomainSet: sync.Mutex{}, realDomainSet: bloom.NewWithEstimates(2048, 0.001), lanInterface: global.LanInterface, wanInterface: global.WanInterface, sniffingTimeout: sniffingTimeout, tproxyPortProtect: global.TproxyPortProtect, soMarkFromDae: global.SoMarkFromDae, mptcp: global.Mptcp, } defer func() { if err != nil { cancel() } }() /// DNS upstream. dnsUpstream, err := dns.New(dnsConfig, &dns.NewOption{ Logger: log, LocationFinder: locationFinder, UpstreamReadyCallback: plane.dnsUpstreamReadyCallback, UpstreamResolverNetwork: common.MagicNetwork("udp", global.SoMarkFromDae, global.Mptcp), }) if err != nil { return nil, err } /// Dns controller. fixedDomainTtl, err := ParseFixedDomainTtl(dnsConfig.FixedDomainTtl) if err != nil { return nil, err } if plane.dnsController, err = NewDnsController(dnsUpstream, &DnsControllerOption{ Log: log, CacheAccessCallback: func(cache *DnsCache) (err error) { // Write mappings into eBPF map: // IP record (from dns lookup) -> domain routing if err = core.BatchUpdateDomainRouting(cache); err != nil { return fmt.Errorf("BatchUpdateDomainRouting: %w", err) } return nil }, CacheRemoveCallback: func(cache *DnsCache) (err error) { // Write mappings into eBPF map: // IP record (from dns lookup) -> domain routing if err = core.BatchRemoveDomainRouting(cache); err != nil { return fmt.Errorf("BatchUpdateDomainRouting: %w", err) } return nil }, NewCache: func(fqdn string, answers []dnsmessage.RR, deadline time.Time, originalDeadline time.Time) (cache *DnsCache, err error) { return &DnsCache{ DomainBitmap: plane.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn), Answer: answers, Deadline: deadline, OriginalDeadline: originalDeadline, }, nil }, BestDialerChooser: plane.chooseBestDnsDialer, TimeoutExceedCallback: func(dialArgument *dialArgument, err error) { dialArgument.bestDialer.ReportUnavailable(&dialer.NetworkType{ L4Proto: dialArgument.l4proto, IpVersion: dialArgument.ipversion, IsDns: true, }, err) }, IpVersionPrefer: dnsConfig.IpVersionPrefer, FixedDomainTtl: fixedDomainTtl, }); err != nil { return nil, err } // Refresh domain routing cache with new routing. // FIXME: We temperarily disable it because we want to make change of DNS section take effects immediately. // TODO: Add change detection. if false && len(dnsCache) > 0 { for cacheKey, cache := range dnsCache { // Also refresh out-dated routing because kernel map items have no expiration. lastDot := strings.LastIndex(cacheKey, ".") if lastDot == -1 || lastDot == len(cacheKey)-1 { // Not a valid key. log.Warnln("Invalid cache key:", cacheKey) continue } host := cacheKey[:lastDot] _typ := cacheKey[lastDot+1:] typ, err := strconv.ParseUint(_typ, 10, 16) if err != nil { // Unexpected. return nil, err } _ = plane.dnsController.UpdateDnsCacheDeadline(host, uint16(typ), cache.Answer, cache.Deadline) } } else if _bpf != nil { // Is reloading, and dnsCache == nil. // Remove all map items. // Normally, it is due to the change of ip version preference. var key [4]uint32 var val bpfDomainRouting iter := core.bpf.DomainRoutingMap.Iterate() for iter.Next(&key, &val) { _ = core.bpf.DomainRoutingMap.Delete(&key) } } // Init immediately to avoid DNS leaking in the very beginning because param control_plane_dns_routing will // be set in callback. if err = dnsUpstream.CheckUpstreamsFormat(); err != nil { return nil, err } go dnsUpstream.InitUpstreams() close(plane.ready) return plane, nil } func ParseFixedDomainTtl(ks []config.KeyableString) (map[string]int, error) { m := make(map[string]int) for _, k := range ks { key, value, _ := strings.Cut(string(k), ":") ttl, err := strconv.ParseInt(strings.TrimSpace(value), 0, strconv.IntSize) if err != nil { return nil, fmt.Errorf("failed to parse ttl: %v", err) } m[strings.TrimSpace(key)] = int(ttl) } return m, nil } func ParseGroupOverrideOption(group config.Group, global config.Global, log *logrus.Logger) (*dialer.GlobalOption, error) { result := global changed := false if group.TcpCheckUrl != nil { result.TcpCheckUrl = group.TcpCheckUrl changed = true } if group.TcpCheckHttpMethod != "" { result.TcpCheckHttpMethod = group.TcpCheckHttpMethod changed = true } if group.UdpCheckDns != nil { result.UdpCheckDns = group.UdpCheckDns changed = true } if group.CheckInterval != 0 { result.CheckInterval = group.CheckInterval changed = true } if group.CheckTolerance != 0 { result.CheckTolerance = group.CheckTolerance changed = true } if changed { option := dialer.NewGlobalOption(&result, log) return option, nil } return nil, nil } // EjectBpf will resect bpf from destroying life-cycle of control plane. func (c *ControlPlane) EjectBpf() *bpfObjects { return c.core.EjectBpf() } func (c *ControlPlane) InjectBpf(bpf *bpfObjects) { c.core.InjectBpf(bpf) } func (c *ControlPlane) CloneDnsCache() map[string]*DnsCache { c.dnsController.dnsCacheMu.Lock() defer c.dnsController.dnsCacheMu.Unlock() return deepcopy.Copy(c.dnsController.dnsCache).(map[string]*DnsCache) } func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err error) { // Waiting for ready. select { case <-c.ctx.Done(): return nil case <-c.ready: } /// Notify dialers to check. c.onceNetworkReady.Do(func() { for _, out := range c.outbounds { for _, d := range out.Dialers { d.NotifyCheck() } } }) if dnsUpstream == nil { return nil } /// Updates dns cache to support domain routing for hostname of dns_upstream. // Ten years later. deadline := time.Now().Add(time.Hour * 24 * 365 * 10) fqdn := dnsmessage.CanonicalName(dnsUpstream.Hostname) if dnsUpstream.Ip4.IsValid() { typ := dnsmessage.TypeA answers := []dnsmessage.RR{&dnsmessage.A{ Hdr: dnsmessage.RR_Header{ Name: dnsmessage.CanonicalName(fqdn), Rrtype: typ, Class: dnsmessage.ClassINET, Ttl: 0, // Must be zero. }, A: dnsUpstream.Ip4.AsSlice(), }} if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil { return err } } if dnsUpstream.Ip6.IsValid() { typ := dnsmessage.TypeAAAA answers := []dnsmessage.RR{&dnsmessage.AAAA{ Hdr: dnsmessage.RR_Header{ Name: dnsmessage.CanonicalName(fqdn), Rrtype: typ, Class: dnsmessage.ClassINET, Ttl: 0, // Must be zero. }, AAAA: dnsUpstream.Ip6.AsSlice(), }} if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil { return err } } return nil } func (c *ControlPlane) ActivateCheck() { for _, g := range c.outbounds { for _, d := range g.Dialers { // We only activate check of nodes that have a group. d.ActivateCheck() } } } func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip.AddrPort, domain string) (dialTarget string, shouldReroute bool, dialIp bool) { dialMode := consts.DialMode_Ip if !outbound.IsReserved() && domain != "" { switch c.dialMode { case consts.DialMode_Domain: if cache := c.dnsController.LookupDnsRespCache(c.dnsController.cacheKey(domain, common.AddrToDnsType(dst.Addr())), true); cache != nil { // Has A/AAAA records. It is a real domain. dialMode = consts.DialMode_Domain } else { // Check if the domain is in real-domain set (bloom filter). c.muRealDomainSet.Lock() if c.realDomainSet.TestString(domain) { c.muRealDomainSet.Unlock() dialMode = consts.DialMode_Domain // Should use this domain to reroute shouldReroute = true } else { c.muRealDomainSet.Unlock() // Lookup A/AAAA to make sure it is a real domain. ctx, cancel := context.WithTimeout(context.TODO(), 5*time.Second) defer cancel() // TODO: use DNS controller and re-route by control plane. systemDns, err := netutils.SystemDns() if err == nil { if ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, domain, common.MagicNetwork("udp", c.soMarkFromDae, c.mptcp), true); err == nil && (ip46.Ip4.IsValid() || ip46.Ip6.IsValid()) { // Has A/AAAA records. It is a real domain. dialMode = consts.DialMode_Domain // Add it to real-domain set. c.muRealDomainSet.Lock() c.realDomainSet.AddString(domain) c.muRealDomainSet.Unlock() // Should use this domain to reroute shouldReroute = true } } } } case consts.DialMode_DomainCao: shouldReroute = true fallthrough case consts.DialMode_DomainPlus: dialMode = consts.DialMode_Domain } } switch dialMode { case consts.DialMode_Ip: dialTarget = dst.String() dialIp = true case consts.DialMode_Domain: if strings.HasPrefix(domain, "[") && strings.HasSuffix(domain, "]") { // Sniffed domain may be like `[2606:4700:20::681a:d1f]`. We should remove the brackets. domain = domain[1 : len(domain)-1] } if _, err := netip.ParseAddr(domain); err == nil { // domain is IPv4 or IPv6 (has colon) dialTarget = net.JoinHostPort(domain, strconv.Itoa(int(dst.Port()))) dialIp = true } else if _, _, err := net.SplitHostPort(domain); err == nil { // domain is already domain:port dialTarget = domain } else { dialTarget = net.JoinHostPort(domain, strconv.Itoa(int(dst.Port()))) } c.log.WithFields(logrus.Fields{ "from": dst.String(), "to": dialTarget, }).Debugln("Rewrite dial target to domain") } return dialTarget, shouldReroute, dialIp } type Listener struct { tcpListener net.Listener packetConn net.PacketConn port uint16 } func (l *Listener) Close() error { var ( err error err2 error ) if err, err2 = l.tcpListener.Close(), l.packetConn.Close(); err2 != nil { if err == nil { err = err2 } else { err = fmt.Errorf("%w: %v", err, err2) } } return err } func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err error) { sentReady := false defer func() { if !sentReady { readyChan <- false } }() udpConn := listener.packetConn.(*net.UDPConn) /// Serve. // TCP socket. tcpFile, err := listener.tcpListener.(*net.TCPListener).File() if err != nil { return fmt.Errorf("failed to retrieve copy of the underlying TCP connection file") } c.deferFuncs = append(c.deferFuncs, func() error { return tcpFile.Close() }) if err := c.core.bpf.ListenSocketMap.Update(consts.ZeroKey, uint64(tcpFile.Fd()), ebpf.UpdateAny); err != nil { return err } // UDP socket. udpFile, err := udpConn.File() if err != nil { return fmt.Errorf("failed to retrieve copy of the underlying UDP connection file") } c.deferFuncs = append(c.deferFuncs, func() error { return udpFile.Close() }) if err := c.core.bpf.ListenSocketMap.Update(consts.OneKey, uint64(udpFile.Fd()), ebpf.UpdateAny); err != nil { return err } sentReady = true readyChan <- true go func() { for { select { case <-c.ctx.Done(): return default: } lconn, err := listener.tcpListener.Accept() if err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { c.log.Errorf("Error when accept: %v", err) } break } go func(lconn net.Conn) { c.inConnections.Store(lconn, struct{}{}) defer c.inConnections.Delete(lconn) if err := c.handleConn(lconn); err != nil { c.log.Warnln("handleConn:", err) } }(lconn) } }() go func() { buf := pool.GetFullCap(consts.EthernetMtu) var oob [120]byte // Size for original dest defer buf.Put() for { select { case <-c.ctx.Done(): return default: } n, oobn, _, src, err := udpConn.ReadMsgUDPAddrPort(buf, oob[:]) if err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { c.log.Errorf("ReadFromUDPAddrPort: %v, %v", src.String(), err) } break } newBuf := pool.Get(n) copy(newBuf, buf[:n]) newOob := pool.Get(oobn) copy(newOob, oob[:oobn]) newSrc := src convergeSrc := common.ConvergeAddrPort(src) // Debug: // t := time.Now() DefaultUdpTaskPool.EmitTask(convergeSrc.String(), func() { data := newBuf oob := newOob src := newSrc defer data.Put() defer oob.Put() var realDst netip.AddrPort var routingResult *bpfRoutingResult pktDst := RetrieveOriginalDest(oob) routingResult, err := c.core.RetrieveRoutingResult(src, pktDst, unix.IPPROTO_UDP) if err != nil { c.log.Warnf("No AddrPort presented: %v", err) return } else { realDst = pktDst } if e := c.handlePkt(udpConn, data, convergeSrc, common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult, false); e != nil { c.log.Warnln("handlePkt:", e) } }) // if d := time.Since(t); d > 100*time.Millisecond { // logrus.Println(d) // } } }() c.ActivateCheck() <-c.ctx.Done() return nil } func (c *ControlPlane) ListenAndServe(readyChan chan<- bool, port uint16) (listener *Listener, err error) { // Listen. var listenConfig = net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { return dialer.TproxyControl(c) }, } listenAddr := net.JoinHostPort(c.listenIp, strconv.Itoa(int(port))) tcpListener, err := listenConfig.Listen(context.TODO(), "tcp", listenAddr) if err != nil { return nil, fmt.Errorf("listenTCP: %w", err) } packetConn, err := listenConfig.ListenPacket(context.TODO(), "udp", listenAddr) if err != nil { _ = tcpListener.Close() return nil, fmt.Errorf("listenUDP: %w", err) } listener = &Listener{ tcpListener: tcpListener, packetConn: packetConn, port: port, } defer func() { if err != nil { _ = listener.Close() } }() // Serve if err = c.Serve(readyChan, listener); err != nil { return nil, fmt.Errorf("failed to serve: %w", err) } return listener, nil } func (c *ControlPlane) chooseBestDnsDialer( req *udpRequest, dnsUpstream *dns.Upstream, ) (*dialArgument, error) { /// Choose the best l4proto+ipversion dialer, and change taregt DNS to the best ipversion DNS upstream for DNS request. // Get available ipversions and l4protos for DNS upstream. ipversions, l4protos := dnsUpstream.SupportedNetworks() var ( bestLatency time.Duration l4proto consts.L4ProtoStr ipversion consts.IpVersionStr bestDialer *dialer.Dialer bestOutbound *outbound.DialerGroup bestTarget netip.AddrPort dialMark uint32 ) // Get the min latency path. networkType := dialer.NetworkType{ IsDns: true, } for _, ver := range ipversions { for _, proto := range l4protos { networkType.L4Proto = proto networkType.IpVersion = ver var dAddr netip.Addr switch ver { case consts.IpVersionStr_4: dAddr = dnsUpstream.Ip4 case consts.IpVersionStr_6: dAddr = dnsUpstream.Ip6 default: return nil, fmt.Errorf("unexpected ipversion: %v", ver) } outboundIndex, mark, _, err := c.Route(req.realSrc, netip.AddrPortFrom(dAddr, dnsUpstream.Port), dnsUpstream.Hostname, proto.ToL4ProtoType(), req.routingResult) if err != nil { return nil, err } if mark == 0 { mark = c.soMarkFromDae } if int(outboundIndex) >= len(c.outbounds) { return nil, fmt.Errorf("bad outbound index: %v", outboundIndex) } dialerGroup := c.outbounds[outboundIndex] // DNS always dial IP. d, latency, err := dialerGroup.Select(&networkType, true) if err != nil { continue } //if c.log.IsLevelEnabled(logrus.TraceLevel) { // c.log.WithFields(logrus.Fields{ // "name": d.Name(), // "latency": latency, // "network": networkType.String(), // "outbound": dialerGroup.Name, // }).Traceln("Choice") //} if bestDialer == nil || latency < bestLatency { bestDialer = d bestOutbound = dialerGroup bestLatency = latency l4proto = proto ipversion = ver dialMark = mark if bestLatency == 0 { break } } } } if bestDialer == nil { return nil, fmt.Errorf("no proper dialer for DNS upstream: %v", dnsUpstream.String()) } switch ipversion { case consts.IpVersionStr_4: bestTarget = netip.AddrPortFrom(dnsUpstream.Ip4, dnsUpstream.Port) case consts.IpVersionStr_6: bestTarget = netip.AddrPortFrom(dnsUpstream.Ip6, dnsUpstream.Port) } if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ "ipversions": ipversions, "l4protos": l4protos, "upstream": dnsUpstream.String(), "choose": string(l4proto) + "+" + string(ipversion), "use": bestTarget.String(), }).Traceln("Choose DNS path") } return &dialArgument{ l4proto: l4proto, ipversion: ipversion, bestDialer: bestDialer, bestOutbound: bestOutbound, bestTarget: bestTarget, mark: dialMark, mptcp: c.mptcp, }, nil } func (c *ControlPlane) AbortConnections() (err error) { var errs []error c.inConnections.Range(func(key, value any) bool { if err = key.(net.Conn).Close(); err != nil { errs = append(errs, err) } return true }) return errors.Join(errs...) } func (c *ControlPlane) Close() (err error) { // Invoke defer funcs in reverse order. for i := len(c.deferFuncs) - 1; i >= 0; i-- { if e := c.deferFuncs[i](); e != nil { // Combine errors. if err != nil { err = fmt.Errorf("%w; %v", err, e) } else { err = e } } } c.cancel() return c.core.Close() }