diff --git a/common/netutils/dns.go b/common/netutils/dns.go index 5813476..204a013 100644 --- a/common/netutils/dns.go +++ b/common/netutils/dns.go @@ -13,8 +13,29 @@ import ( "golang.org/x/net/proxy" "net/netip" "strings" + "sync" + "time" ) +var ( + systemDnsMu sync.Mutex + systemDns netip.AddrPort +) + +func SystemDns() (dns netip.AddrPort, err error) { + systemDnsMu.Lock() + defer systemDnsMu.Unlock() + if !systemDns.IsValid() { + dnsConf := dnsReadConfig("/etc/resolv.conf") + if len(dnsConf.servers) == 0 { + err = fmt.Errorf("no valid dns server in /etc/resolv.conf") + return netip.AddrPort{}, err + } + systemDns = netip.MustParseAddrPort(dnsConf.servers[0]) + } + return systemDns, nil +} + func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type) (addrs []netip.Addr, err error) { if addr, err := netip.ParseAddr(host); err == nil { if (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA { @@ -61,7 +82,23 @@ func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host if err != nil { return nil, err } - ch := make(chan error, 1) + ch := make(chan error, 2) + go func() { + // Resend every 3 seconds. + for { + select { + case <-ctx.Done(): + return + default: + time.Sleep(3 * time.Second) + } + _, err := c.Write(b) + if err != nil { + ch <- err + return + } + } + }() go func() { buf := pool.Get(512) n, err := c.Read(buf) diff --git a/common/netutils/dnsconfig_unix.go b/common/netutils/dnsconfig_unix.go new file mode 100644 index 0000000..66e50f0 --- /dev/null +++ b/common/netutils/dnsconfig_unix.go @@ -0,0 +1,196 @@ +// Modified from go1.18/src/net/dnsconfig_unix.go + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris + +// Read system DNS config from /etc/resolv.conf + +package netutils + +import ( + "bufio" + "net" + "net/netip" + "os" + "strconv" + "strings" + "sync/atomic" + "time" +) + +var ( + defaultNS = []string{"127.0.0.1:53", "[::1]:53"} + getHostname = os.Hostname // variable for testing +) + +type dnsConfig struct { + servers []string // server addresses (in host:port form) to use + search []string // rooted suffixes to append to local name + ndots int // number of dots in name to trigger absolute lookup + timeout time.Duration // wait before giving up on a query, including retries + attempts int // lost packets before giving up on server + rotate bool // round robin among servers + unknownOpt bool // anything unknown was encountered + lookup []string // OpenBSD top-level database "lookup" order + err error // any error that occurs during open of resolv.conf + mtime time.Time // time of resolv.conf modification + soffset uint32 // used by serverOffset + singleRequest bool // use sequential A and AAAA queries instead of parallel queries + useTCP bool // force usage of TCP for DNS resolutions +} + +// See resolv.conf(5) on a Linux machine. +func dnsReadConfig(filename string) *dnsConfig { + conf := &dnsConfig{ + ndots: 1, + timeout: 5 * time.Second, + attempts: 2, + } + file, err := os.Open(filename) + if err != nil { + conf.servers = defaultNS + conf.search = dnsDefaultSearch() + conf.err = err + return conf + } + defer file.Close() + if fi, err := file.Stat(); err == nil { + conf.mtime = fi.ModTime() + } else { + conf.servers = defaultNS + conf.search = dnsDefaultSearch() + conf.err = err + return conf + } + fio := bufio.NewReader(file) + for line, _, err := fio.ReadLine(); err == nil; line, _, err = fio.ReadLine() { + if len(line) > 0 && (line[0] == ';' || line[0] == '#') { + // comment. + continue + } + f := strings.Fields(string(line)) + if len(f) < 1 { + continue + } + switch f[0] { + case "nameserver": // add one name server + if len(f) > 1 && len(conf.servers) < 3 { // small, but the standard limit + // One more check: make sure server name is + // just an IP address. Otherwise we need DNS + // to look it up. + if _, e := netip.ParseAddr(f[1]); e == nil { + conf.servers = append(conf.servers, net.JoinHostPort(f[1], "53")) + } + } + + case "domain": // set search path to just this domain + if len(f) > 1 { + conf.search = []string{ensureRooted(f[1])} + } + + case "search": // set search path to given servers + conf.search = make([]string, len(f)-1) + for i := 0; i < len(conf.search); i++ { + conf.search[i] = ensureRooted(f[i+1]) + } + + case "options": // magic options + for _, s := range f[1:] { + switch { + case hasPrefix(s, "ndots:"): + n, _ := strconv.Atoi(s[6:]) + if n < 0 { + n = 0 + } else if n > 15 { + n = 15 + } + conf.ndots = n + case hasPrefix(s, "timeout:"): + n, _ := strconv.Atoi(s[8:]) + if n < 1 { + n = 1 + } + conf.timeout = time.Duration(n) * time.Second + case hasPrefix(s, "attempts:"): + n, _ := strconv.Atoi(s[9:]) + if n < 1 { + n = 1 + } + conf.attempts = n + case s == "rotate": + conf.rotate = true + case s == "single-request" || s == "single-request-reopen": + // Linux option: + // http://man7.org/linux/man-pages/man5/resolv.conf.5.html + // "By default, glibc performs IPv4 and IPv6 lookups in parallel [...] + // This option disables the behavior and makes glibc + // perform the IPv6 and IPv4 requests sequentially." + conf.singleRequest = true + case s == "use-vc" || s == "usevc" || s == "tcp": + // Linux (use-vc), FreeBSD (usevc) and OpenBSD (tcp) option: + // http://man7.org/linux/man-pages/man5/resolv.conf.5.html + // "Sets RES_USEVC in _res.options. + // This option forces the use of TCP for DNS resolutions." + // https://www.freebsd.org/cgi/man.cgi?query=resolv.conf&sektion=5&manpath=freebsd-release-ports + // https://man.openbsd.org/resolv.conf.5 + conf.useTCP = true + default: + conf.unknownOpt = true + } + } + + case "lookup": + // OpenBSD option: + // https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5 + // "the legal space-separated values are: bind, file, yp" + conf.lookup = f[1:] + + default: + conf.unknownOpt = true + } + } + if len(conf.servers) == 0 { + conf.servers = defaultNS + } + if len(conf.search) == 0 { + conf.search = dnsDefaultSearch() + } + return conf +} + +// serverOffset returns an offset that can be used to determine +// indices of servers in c.servers when making queries. +// When the rotate option is enabled, this offset increases. +// Otherwise it is always 0. +func (c *dnsConfig) serverOffset() uint32 { + if c.rotate { + return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start + } + return 0 +} + +func dnsDefaultSearch() []string { + hn, err := getHostname() + if err != nil { + // best effort + return nil + } + if i := strings.IndexByte(hn, '.'); i >= 0 && i < len(hn)-1 { + return []string{ensureRooted(hn[i+1:])} + } + return nil +} + +func hasPrefix(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +func ensureRooted(s string) string { + if len(s) > 0 && s[len(s)-1] == '.' { + return s + } + return s + "." +} diff --git a/common/netutils/ip46.go b/common/netutils/ip46.go new file mode 100644 index 0000000..69ce023 --- /dev/null +++ b/common/netutils/ip46.go @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: AGPL-3.0-only + * Copyright (c) since 2023, mzz2017 + */ + +package netutils + +import ( + "context" + "fmt" + "golang.org/x/net/dns/dnsmessage" + "golang.org/x/net/proxy" + "net/netip" +) + +type Ip46 struct { + Ip4 netip.Addr + Ip6 netip.Addr +} + +func ParseIp46(ctx context.Context, dialer proxy.Dialer, dns netip.AddrPort, host string, must46 bool) (ipv46 *Ip46, err error) { + addrs4, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeA) + if err != nil { + return nil, err + } + if len(addrs4) == 0 && must46 { + if must46 { + return nil, fmt.Errorf("domain \"%v\" has no ipv4 record", host) + } else { + addrs4 = []netip.Addr{{}} + } + } + addrs6, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeAAAA) + if err != nil { + return nil, err + } + if len(addrs6) == 0 { + if must46 { + return nil, fmt.Errorf("domain \"%v\" has no ipv6 record", host) + } else { + addrs6 = []netip.Addr{{}} + } + } + return &Ip46{ + Ip4: addrs4[0], + Ip6: addrs6[0], + }, nil +} diff --git a/component/outbound/dialer/alive_dialer_set.go b/component/outbound/dialer/alive_dialer_set.go index 78346cd..1cda121 100644 --- a/component/outbound/dialer/alive_dialer_set.go +++ b/component/outbound/dialer/alive_dialer_set.go @@ -83,8 +83,8 @@ func (a *AliveDialerSet) GetRand() *Dialer { } // GetMinLatency acquires correct selectionPolicy. -func (a *AliveDialerSet) GetMinLatency() *Dialer { - return a.minLatency.dialer +func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) { + return a.minLatency.dialer, a.minLatency.latency } // NotifyLatencyChange should be invoked when dialer every time latency and alive state changes. diff --git a/component/outbound/dialer/connectivity_check.go b/component/outbound/dialer/connectivity_check.go index b9f7654..9323176 100644 --- a/component/outbound/dialer/connectivity_check.go +++ b/component/outbound/dialer/connectivity_check.go @@ -24,10 +24,6 @@ import ( "time" ) -var ( - BootstrapDns = netip.MustParseAddrPort("223.5.5.5:53") -) - type collection struct { // AliveDialerSetSet uses reference counting. AliveDialerSetSet AliveDialerSetSet @@ -71,43 +67,22 @@ func (d *Dialer) MustGetAlive(l4proto consts.L4ProtoStr, ipversion consts.IpVers return d.mustGetCollection(l4proto, ipversion).Alive } -type Ip46 struct { - Ip4 netip.Addr - Ip6 netip.Addr -} - -func ParseIp46(ctx context.Context, host string) (ipv46 *Ip46, err error) { - addrs4, err := netutils.ResolveNetip(ctx, SymmetricDirect, BootstrapDns, host, dnsmessage.TypeA) - if err != nil { - return nil, err - } - if len(addrs4) == 0 { - return nil, fmt.Errorf("domain \"%v\" has no ipv4 record", host) - } - addrs6, err := netutils.ResolveNetip(ctx, SymmetricDirect, BootstrapDns, host, dnsmessage.TypeAAAA) - if err != nil { - return nil, err - } - if len(addrs6) == 0 { - return nil, fmt.Errorf("domain \"%v\" has no ipv6 record", host) - } - return &Ip46{ - Ip4: addrs4[0], - Ip6: addrs6[0], - }, nil -} - type TcpCheckOption struct { Url *netutils.URL - *Ip46 + *netutils.Ip46 } func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOption, err error) { + systemDns, err := netutils.SystemDns() + if err != nil { + return nil, err + } + u, err := url.Parse(rawURL) if err != nil { return nil, err } - ip46, err := ParseIp46(ctx, u.Hostname()) + ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, u.Hostname(), true) if err != nil { return nil, err } @@ -120,10 +95,15 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio type UdpCheckOption struct { DnsHost string DnsPort uint16 - *Ip46 + *netutils.Ip46 } func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheckOption, err error) { + systemDns, err := netutils.SystemDns() + if err != nil { + return nil, err + } + host, _port, err := net.SplitHostPort(dnsHostPort) if err != nil { return nil, err @@ -132,7 +112,7 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck if err != nil { return nil, fmt.Errorf("bad port: %v", err) } - ip46, err := ParseIp46(ctx, host) + ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, host, true) if err != nil { return nil, err } diff --git a/component/outbound/dialer_group.go b/component/outbound/dialer_group.go index df6714e..c87b1ff 100644 --- a/component/outbound/dialer_group.go +++ b/component/outbound/dialer_group.go @@ -15,6 +15,7 @@ import ( "net" "net/netip" "strings" + "time" ) type DialerGroup struct { @@ -95,9 +96,9 @@ func (g *DialerGroup) SetSelectionPolicy(policy DialerSelectionPolicy) { } // Select selects a dialer from group according to selectionPolicy. -func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) (*dialer.Dialer, error) { +func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) (d *dialer.Dialer, latency time.Duration, err error) { if len(g.Dialers) == 0 { - return nil, fmt.Errorf("no dialer in this group") + return nil, 0, fmt.Errorf("no dialer in this group") } var a *dialer.AliveDialerSet switch l4proto { @@ -116,7 +117,7 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi a = g.AliveUdp6DialerSet } default: - return nil, fmt.Errorf("DialerGroup.Select: unexpected l4proto type: %v", l4proto) + return nil, 0, fmt.Errorf("DialerGroup.Select: unexpected l4proto type: %v", l4proto) } switch g.selectionPolicy.Policy { @@ -128,30 +129,30 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi "network": string(l4proto) + string(ipversion), "group": g.Name, }).Warnf("No alive dialer in DialerGroup, use \"block\".") - return g.block, nil + return g.block, 0, nil } - return d, nil + return d, 0, nil case consts.DialerSelectionPolicy_Fixed: if g.selectionPolicy.FixedIndex < 0 || g.selectionPolicy.FixedIndex >= len(g.Dialers) { - return nil, fmt.Errorf("selected dialer index is out of range") + return nil, 0, fmt.Errorf("selected dialer index is out of range") } - return g.Dialers[g.selectionPolicy.FixedIndex], nil + return g.Dialers[g.selectionPolicy.FixedIndex], 0, nil case consts.DialerSelectionPolicy_MinLastLatency, consts.DialerSelectionPolicy_MinAverage10Latencies: - d := a.GetMinLatency() + d, latency := a.GetMinLatency() if d == nil { // No alive dialer. g.log.WithFields(logrus.Fields{ "network": string(l4proto) + string(ipversion), "group": g.Name, }).Warnf("No alive dialer in DialerGroup, use \"block\".") - return g.block, nil + return g.block, 0, nil } - return d, nil + return d, latency, nil default: - return nil, fmt.Errorf("unsupported DialerSelectionPolicy: %v", g.selectionPolicy) + return nil, 0, fmt.Errorf("unsupported DialerSelectionPolicy: %v", g.selectionPolicy) } } @@ -164,9 +165,9 @@ func (g *DialerGroup) Dial(network string, addr string) (c net.Conn, err error) ipversion := consts.IpVersionFromAddr(ipAddr) switch { case strings.HasPrefix(network, "tcp"): - d, err = g.Select(consts.L4ProtoStr_TCP, ipversion) + d, _, err = g.Select(consts.L4ProtoStr_TCP, ipversion) case strings.HasPrefix(network, "udp"): - d, err = g.Select(consts.L4ProtoStr_UDP, ipversion) + d, _, err = g.Select(consts.L4ProtoStr_UDP, ipversion) default: return nil, fmt.Errorf("unexpected network: %v", network) } diff --git a/component/outbound/dialer_group_test.go b/component/outbound/dialer_group_test.go index 61916c6..f8236f0 100644 --- a/component/outbound/dialer_group_test.go +++ b/component/outbound/dialer_group_test.go @@ -44,9 +44,9 @@ func TestDialerGroup_Select_Fixed(t *testing.T) { g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{ Policy: consts.DialerSelectionPolicy_Fixed, FixedIndex: fixedIndex, - }) + }, func(alive bool, l4proto uint8, ipversion uint8) {}) for i := 0; i < 10; i++ { - d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) + d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) if err != nil { t.Fatal(err) } @@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) { fixedIndex = 0 g.selectionPolicy.FixedIndex = fixedIndex for i := 0; i < 10; i++ { - d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) + d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) if err != nil { t.Fatal(err) } @@ -98,7 +98,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) { } g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{ Policy: consts.DialerSelectionPolicy_MinLastLatency, - }) + }, func(alive bool, l4proto uint8, ipversion uint8) {}) // Test 1000 times. for i := 0; i < 1000; i++ { @@ -127,7 +127,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) { } g.AliveTcp4DialerSet.NotifyLatencyChange(d, alive) } - d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) + d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) if err != nil { t.Fatal(err) } @@ -170,10 +170,10 @@ func TestDialerGroup_Select_Random(t *testing.T) { } g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{ Policy: consts.DialerSelectionPolicy_Random, - }) + }, func(alive bool, l4proto uint8, ipversion uint8) {}) count := make([]int, len(dialers)) for i := 0; i < 100; i++ { - d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) + d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4) if err != nil { t.Fatal(err) } @@ -217,12 +217,12 @@ func TestDialerGroup_SetAlive(t *testing.T) { } g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{ Policy: consts.DialerSelectionPolicy_Random, - }) + }, func(alive bool, l4proto uint8, ipversion uint8) {}) zeroTarget := 3 g.AliveTcp4DialerSet.NotifyLatencyChange(dialers[zeroTarget], false) count := make([]int, len(dialers)) for i := 0; i < 100; i++ { - d, err := g.Select(consts.L4ProtoStr_UDP, consts.IpVersionStr_4) + d, _, err := g.Select(consts.L4ProtoStr_UDP, consts.IpVersionStr_4) if err != nil { t.Fatal(err) } diff --git a/control/control_plane.go b/control/control_plane.go index a9c898e..af5eb3e 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -20,6 +20,7 @@ import ( "github.com/v2rayA/dae/config" "github.com/v2rayA/dae/pkg/config_parser" internal "github.com/v2rayA/dae/pkg/ebpf_internal" + "golang.org/x/net/dns/dnsmessage" "golang.org/x/sys/unix" "net" "net/netip" @@ -50,7 +51,7 @@ type ControlPlane struct { // mutex protects the dnsCache. mutex sync.Mutex dnsCache map[string]*dnsCache - dnsUpstream netip.AddrPort + dnsUpstream *DnsUpstraem } func NewControlPlane( @@ -189,7 +190,7 @@ func NewControlPlane( } /// DialerGroups (outbounds). - ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.TODO(), 30*time.Second) defer cancel() tcpCheckOption, err := dialer.ParseTcpCheckOption(ctx, global.TcpCheckUrl) if err != nil { @@ -285,22 +286,78 @@ func NewControlPlane( } /// DNS upstream. - var dnsAddrPort netip.AddrPort + var dnsUpstream *DnsUpstraem if !global.DnsUpstream.Empty { - if dnsAddrPort, err = resolveDnsUpstream(global.DnsUpstream.Url); err != nil { + if dnsUpstream, err = ResolveDnsUpstream(ctx, global.DnsUpstream.Url); err != nil { return nil, err } - dnsAddr16 := dnsAddrPort.Addr().As16() - if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{ - Ip: common.Ipv6ByteSliceToUint32Array(dnsAddr16[:]), - Port: internal.Htons(dnsAddrPort.Port()), + ip4in6 := dnsUpstream.Ip4.As16() + ip6 := dnsUpstream.Ip6.As16() + if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{ + Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]), + Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]), + HasIp4: dnsUpstream.Ip4.IsValid(), + HasIp6: dnsUpstream.Ip6.IsValid(), + Port: internal.Htons(dnsUpstream.Port), }, ebpf.UpdateAny); err != nil { return nil, err } + defer func() { + // Update dns cache to support domain routing for hostname of dns_upstream. + if err == nil { + // Ten years later. + deadline := time.Now().Add(24 * time.Hour * 365 * 10) + fqdn := dnsUpstream.Hostname + if !strings.HasSuffix(fqdn, ".") { + fqdn = fqdn + "." + } + + if dnsUpstream.Ip4.IsValid() { + typ := dnsmessage.TypeA + answers := []dnsmessage.Resource{{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(fqdn), + Type: typ, + Class: dnsmessage.ClassINET, + TTL: 0, // Must be zero. + }, + Body: &dnsmessage.AResource{ + A: dnsUpstream.Ip4.As4(), + }, + }} + if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil { + c = nil + return + } + } + + if dnsUpstream.Ip6.IsValid() { + typ := dnsmessage.TypeAAAA + answers := []dnsmessage.Resource{{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(fqdn), + Type: typ, + Class: dnsmessage.ClassINET, + TTL: 0, // Must be zero. + }, + Body: &dnsmessage.AAAAResource{ + AAAA: dnsUpstream.Ip6.As16(), + }, + }} + if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil { + c = nil + return + } + } + } + }() } else { - // Empty. - if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{ - Ip: [4]uint32{}, + // Empty string. As-is. + if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{ + Ip4: [4]uint32{}, + Ip6: [4]uint32{}, + HasIp4: false, + HasIp6: false, // Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array. Port: 0, }, ebpf.UpdateAny); err != nil { @@ -325,7 +382,7 @@ func NewControlPlane( Final: routingA.Final, mutex: sync.Mutex{}, dnsCache: make(map[string]*dnsCache), - dnsUpstream: dnsAddrPort, + dnsUpstream: dnsUpstream, }, nil } diff --git a/control/control_utils.go b/control/control_utils.go deleted file mode 100644 index 04db246..0000000 --- a/control/control_utils.go +++ /dev/null @@ -1,34 +0,0 @@ -/* - * SPDX-License-Identifier: AGPL-3.0-only - * Copyright (c) since 2023, mzz2017 - */ - -package control - -import ( - "fmt" - "net" - "net/netip" - "net/url" -) - -func resolveDnsUpstream(dnsUpstream *url.URL) (addrPort netip.AddrPort, err error) { - if dnsUpstream.Scheme != "udp" { - return netip.AddrPort{}, fmt.Errorf("dns_upstream now only supports udp://") - } - port := dnsUpstream.Port() - if port == "" { - port = "53" - } - hostname := dnsUpstream.Hostname() - ips, _ := net.LookupIP(hostname) - if len(ips) == 0 { - return netip.AddrPort{}, fmt.Errorf("cannot resolve hostname of dns upstream: %v", hostname) - } - // resolve hostname - dnsAddrPort, err := netip.ParseAddrPort(net.JoinHostPort(ips[0].String(), port)) - if err != nil { - return netip.AddrPort{}, fmt.Errorf("failed to parse DNS upstream: \"%v\": %w", dnsUpstream.String(), err) - } - return dnsAddrPort, nil -} diff --git a/control/dns.go b/control/dns.go index 86ac260..f222fe9 100644 --- a/control/dns.go +++ b/control/dns.go @@ -282,27 +282,36 @@ loop: "auth": FormatDnsRsc(msg.Authorities), "addi": FormatDnsRsc(msg.Additionals), }).Tracef("Update DNS record cache") + if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil { + return nil, err + } + // Pack to get newData. + return msg.Pack() +} + +func (c *ControlPlane) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) { c.mutex.Lock() - fqdn := strings.ToLower(q.Name.String()) - cacheKey := fqdn + q.Type.String() + fqdn := strings.ToLower(host) + if !strings.HasSuffix(fqdn, ".") { + fqdn += "." + } + cacheKey := fqdn + typ.String() cache, ok := c.dnsCache[cacheKey] if ok { c.mutex.Unlock() - cache.Deadline = time.Now().Add(time.Duration(ttl)*time.Second + DnsNatTimeout) - cache.Answers = msg.Answers + cache.Deadline = deadline + cache.Answers = answers } else { cache = &dnsCache{ DomainBitmap: c.MatchDomainBitmap(strings.TrimSuffix(fqdn, ".")), - Answers: msg.Answers, - Deadline: time.Now().Add(time.Duration(ttl)*time.Second + DnsNatTimeout), + Answers: answers, + Deadline: deadline, } c.dnsCache[cacheKey] = cache c.mutex.Unlock() } if err = c.BatchUpdateDomainRouting(cache); err != nil { - return nil, fmt.Errorf("BatchUpdateDomainRouting: %w", err) + return fmt.Errorf("BatchUpdateDomainRouting: %w", err) } - - // Pack to get newData. - return msg.Pack() + return nil } diff --git a/control/dns_upstream.go b/control/dns_upstream.go new file mode 100644 index 0000000..928453d --- /dev/null +++ b/control/dns_upstream.go @@ -0,0 +1,89 @@ +/* + * SPDX-License-Identifier: AGPL-3.0-only + * Copyright (c) since 2023, mzz2017 + */ + +package control + +import ( + "context" + "fmt" + "github.com/v2rayA/dae/common/consts" + "github.com/v2rayA/dae/common/netutils" + "github.com/v2rayA/dae/component/outbound/dialer" + "net/url" + "strconv" +) + +type DnsUpstreamScheme string + +const ( + DnsUpstreamScheme_TCP DnsUpstreamScheme = "tcp" + DnsUpstreamScheme_UDP DnsUpstreamScheme = "udp" + DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp" +) + +type DnsUpstraem struct { + Scheme DnsUpstreamScheme + Hostname string + Port uint16 + *netutils.Ip46 +} + +func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstraem, err error) { + var _port string + switch DnsUpstreamScheme(dnsUpstream.Scheme) { + case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP: + _port = dnsUpstream.Port() + if _port == "" { + _port = "53" + } + default: + return nil, fmt.Errorf("dns_upstream now only supports auto://, udp://, tcp:// and empty string (as-is)") + } + + systemDns, err := netutils.SystemDns() + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16) + if err != nil { + return nil, fmt.Errorf("parse dns_upstream port: %v", err) + } + hostname := dnsUpstream.Hostname() + ip46, err := netutils.ParseIp46(ctx, dialer.SymmetricDirect, systemDns, hostname, false) + if err != nil { + return nil, fmt.Errorf("failed to resolve dns_upstream") + } + if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() { + return nil, fmt.Errorf("dns_upstream has no record") + } + return &DnsUpstraem{ + Scheme: DnsUpstreamScheme(dnsUpstream.Scheme), + Hostname: hostname, + Port: uint16(port), + Ip46: ip46, + }, nil +} + +func (u *DnsUpstraem) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) { + if u.Ip4.IsValid() && u.Ip6.IsValid() { + ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6} + } else { + if u.Ip4.IsValid() { + ipversions = []consts.IpVersionStr{consts.IpVersionStr_4} + } else { + ipversions = []consts.IpVersionStr{consts.IpVersionStr_6} + } + } + switch u.Scheme { + case DnsUpstreamScheme_TCP: + l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP} + case DnsUpstreamScheme_UDP: + l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP} + case DnsUpstreamScheme_TCP_UDP: + // UDP first. + l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP, consts.L4ProtoStr_TCP} + } + return ipversions, l4protos +} diff --git a/control/kern/tproxy.c b/control/kern/tproxy.c index bb798a8..b3aca27 100644 --- a/control/kern/tproxy.c +++ b/control/kern/tproxy.c @@ -74,6 +74,16 @@ enum { DisableL4ChecksumPolicy_SetZero, }; +// Param keys: +static const __u32 zero_key = 0; +static const __u32 tproxy_port_key = 1; +static const __u32 one_key = 1; +static const __u32 disable_l4_tx_checksum_key + __attribute__((unused, deprecated)) = 2; +static const __u32 disable_l4_rx_checksum_key + __attribute__((unused, deprecated)) = 3; +static const __u32 control_plane_pid_key = 4; + // Outbound Connectivity Map: struct outbound_connectivity_query { @@ -97,15 +107,8 @@ struct { __uint(max_entries, 2); } listen_socket_map SEC(".maps"); -// Param keys: -static const __u32 zero_key = 0; -static const __u32 tproxy_port_key = 1; -static const __u32 one_key = 1; -static const __u32 disable_l4_tx_checksum_key - __attribute__((unused, deprecated)) = 2; -static const __u32 disable_l4_rx_checksum_key - __attribute__((unused, deprecated)) = 3; -static const __u32 control_plane_pid_key = 4; +/// TODO: Remove items from the dst_map by conntrack. +// Dest map: struct ip_port { __be32 ip[4]; @@ -125,8 +128,6 @@ struct tuples { __u8 l4proto; }; -/// TODO: Remove items from the dst_map by conntrack. -// Dest map: struct { __uint(type, BPF_MAP_TYPE_LRU_HASH); __type(key, @@ -195,10 +196,19 @@ struct { } ipproto_hdrsize_map SEC(".maps"); // Dns upstream: + +struct dns_upstream { + __be32 ip4[4]; + __be32 ip6[4]; + bool hasIp4; + bool hasIp6; + __be16 port; +}; struct { __uint(type, BPF_MAP_TYPE_ARRAY); __type(key, __u32); - __type(value, struct ip_port); + __type(value, struct dns_upstream); + /// FIXME: l4proto is always udp. __uint(max_entries, 1); } dns_upstream_map SEC(".maps"); @@ -974,11 +984,22 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4], // Modify DNS upstream for routing. if (h_dport == 53 && _l4proto_type == L4ProtoType_UDP) { - struct ip_port *upstream = + struct dns_upstream *upstream = bpf_map_lookup_elem(&dns_upstream_map, &zero_key); if (upstream && upstream->port != 0) { h_dport = bpf_ntohs(upstream->port); - __builtin_memcpy(daddr, upstream->ip, IPV6_BYTE_LENGTH); + if (_ipversion_type == IpVersionType_4 && upstream->hasIp4) { + __builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH); + } else if (_ipversion_type == IpVersionType_6 && upstream->hasIp6) { + __builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH); + } else if (upstream->hasIp4) { + __builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH); + } else if (upstream->hasIp6) { + __builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH); + } else { + bpf_printk("bad dns upstream; use as-is."); + __builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH); + } } else { __builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH); } diff --git a/control/tcp.go b/control/tcp.go index 77ba399..f5db3d7 100644 --- a/control/tcp.go +++ b/control/tcp.go @@ -60,7 +60,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) { } l4proto := consts.L4ProtoStr_TCP ipversion := consts.IpVersionFromAddr(dst.Addr()) - dialer, err := outbound.Select(l4proto, ipversion) + dialer, _, err := outbound.Select(l4proto, ipversion) if err != nil { return fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err) } diff --git a/control/udp.go b/control/udp.go index cd756bc..264e842 100644 --- a/control/udp.go +++ b/control/udp.go @@ -14,6 +14,7 @@ import ( "github.com/v2rayA/dae/common/consts" "github.com/v2rayA/dae/component/outbound/dialer" "golang.org/x/net/dns/dnsmessage" + "io" "net" "net/netip" "strings" @@ -94,11 +95,12 @@ func sendPktBind(data []byte, from netip.AddrPort, to netip.AddrPort) error { return err } -func (c *ControlPlane) RelayToUDP(to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAns bool) UdpHandler { +func (c *ControlPlane) WriteToUDP(to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAnsFunc func(from netip.AddrPort) bool) UdpHandler { return func(data []byte, from netip.AddrPort) (err error) { // Do not return conn-unrelated err in this func. if isDNS { + validateRushAns := validateRushAnsFunc(from) data, err = c.DnsRespHandler(data, validateRushAns) if err != nil { if validateRushAns && errors.Is(err, SuspectedRushAnswerError) { @@ -158,15 +160,6 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI return nil } - // Need to make a DNS request. - if c.dnsUpstream.IsValid() { - c.log.Tracef("Modify dns target %v to upstream: %v", RefineAddrPortToShow(destToSend), c.dnsUpstream) - // Modify dns target to upstream. - // NOTICE: Routing was calculated in advance by the eBPF program. - dummyFrom = &dst - destToSend = c.dnsUpstream - } - // Flip dns question to reduce dns pollution. FlipDnsQuestionCase(dnsMessage) // Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process. @@ -180,51 +173,167 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI } } - // We only validate rush-ans when outbound is direct and pkt does not send to a home device. - // Because additional record OPT may not be supported by home router. - // So se should trust home devices even if they make rush-answer (or looks like). - validateRushAns := outboundIndex == consts.OutboundDirect && !destToSend.Addr().IsPrivate() - - // Get udp endpoint. l4proto := consts.L4ProtoStr_UDP ipversion := consts.IpVersionFromAddr(dst.Addr()) -getNew: - ue, isNew, err := DefaultUdpEndpointPool.GetOrCreate(src, &UdpEndpointOptions{ - Handler: c.RelayToUDP(src, isDns, dummyFrom, validateRushAns), - NatTimeout: natTimeout, - DialerFunc: func() (*dialer.Dialer, error) { - newDialer, err := outbound.Select(l4proto, ipversion) - if err != nil { - return nil, fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err) - } - return newDialer, nil - }, - Target: destToSend, - }) - if err != nil { - return fmt.Errorf("failed to GetOrCreate: %w", err) - } - // If the udp endpoint has been not alive, remove it from pool and get a new one. - if !isNew && !ue.Dialer.MustGetAlive(l4proto, ipversion) { - c.log.WithFields(logrus.Fields{ - "src": src.String(), - "network": string(l4proto) + string(ipversion), - "dialer": ue.Dialer.Name(), - }).Debugln("Old udp endpoint is not alive and removed") - _ = DefaultUdpEndpointPool.Remove(src, ue) - goto getNew - } - // This is real dialer. - d := ue.Dialer + var dialerForNew *dialer.Dialer - if isNew { + // For DNS request, modify dst to dns upstream. + // NOTICE: We might modify l4proto and ipversion. + if isDns && c.dnsUpstream != nil { + // Modify dns target to upstream. + // NOTICE: Routing was calculated in advance by the eBPF program. + + /// Choose the best l4proto and ipversion. + // Get available ipversions and l4protos for DNS upstream. + ipversions, l4protos := c.dnsUpstream.SupportedNetworks() + var ( + bestDialer *dialer.Dialer + bestLatency time.Duration + bestTarget netip.AddrPort + ) + c.log.WithFields(logrus.Fields{ + "ipversions": ipversions, + "l4protos": l4protos, + }).Debugln("Choose DNS path") + // Get the min latency path. + for _, ver := range ipversions { + for _, proto := range l4protos { + d, latency, err := outbound.Select(proto, ver) + if err != nil { + continue + } + c.log.WithFields(logrus.Fields{ + "latency": latency, + "ver": ver, + "proto": proto, + "outbound": outbound.Name, + }).Debugln("Choose") + if bestDialer == nil || latency < bestLatency { + bestDialer = d + bestLatency = latency + l4proto = proto + ipversion = ver + } + } + } + switch ipversion { + case consts.IpVersionStr_4: + bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip4, c.dnsUpstream.Port) + case consts.IpVersionStr_6: + bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip6, c.dnsUpstream.Port) + } + dialerForNew = bestDialer + dummyFrom = &dst + destToSend = bestTarget + c.log.WithFields(logrus.Fields{ + "Original": RefineAddrPortToShow(dst), + "New": destToSend, + "Network": string(l4proto) + string(ipversion), + }).Traceln("Modify DNS target") + } + if dialerForNew == nil { + dialerForNew, _, err = outbound.Select(l4proto, ipversion) + if err != nil { + return fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err) + } + } + + var isNew bool + var realDialer *dialer.Dialer + + udpHandler := c.WriteToUDP(src, isDns, dummyFrom, func(from netip.AddrPort) bool { + // We only validate rush-ans when outbound is direct and pkt does not send to a home device. + // Because additional record OPT may not be supported by home router. + // So se should trust home devices even if they make rush-answer (or looks like). + return outboundIndex == consts.OutboundDirect && !from.Addr().IsPrivate() + }) + + // Dial and send. + switch l4proto { + case consts.L4ProtoStr_UDP: + // Get udp endpoint. + var ue *UdpEndpoint + getNew: + ue, isNew, err = DefaultUdpEndpointPool.GetOrCreate(src, &UdpEndpointOptions{ + Handler: udpHandler, + NatTimeout: natTimeout, + DialerFunc: func() (*dialer.Dialer, error) { + return dialerForNew, nil + }, + Target: destToSend, + }) + if err != nil { + return fmt.Errorf("failed to GetOrCreate: %w", err) + } + // If the udp endpoint has been not alive, remove it from pool and get a new one. + if !isNew && !ue.Dialer.MustGetAlive(l4proto, ipversion) { + c.log.WithFields(logrus.Fields{ + "src": src.String(), + "network": string(l4proto) + string(ipversion), + "dialer": ue.Dialer.Name(), + }).Debugln("Old udp endpoint is not alive and removed") + _ = DefaultUdpEndpointPool.Remove(src, ue) + goto getNew + } + // This is real dialer. + realDialer = ue.Dialer + + //log.Printf("WriteToUDPAddrPort->%v", destToSend) + _, err = ue.WriteToUDPAddrPort(data, destToSend) + if err != nil { + return fmt.Errorf("failed to write UDP packet req: %w", err) + } + case consts.L4ProtoStr_TCP: + // MUST be DNS. + if !isDns { + return fmt.Errorf("UDP to TCP only support DNS request") + } + realDialer = dialerForNew + + // We can block because we are in a coroutine. + + conn, err := dialerForNew.Dial("tcp", destToSend.String()) + if err != nil { + return fmt.Errorf("failed to dial proxy to tcp: %w", err) + } + defer conn.Close() + + _ = conn.SetDeadline(time.Now().Add(natTimeout)) + // We should write two byte length in the front of TCP DNS request. + bLen := pool.Get(2) + defer pool.Put(bLen) + binary.BigEndian.PutUint16(bLen, uint16(len(data))) + _, err = conn.Write(bLen) + if err != nil { + return fmt.Errorf("failed to write DNS req length: %w", err) + } + if _, err = conn.Write(data); err != nil { + return fmt.Errorf("failed to write DNS req payload: %w", err) + } + + // Read two byte length. + if _, err = io.ReadFull(conn, bLen); err != nil { + return fmt.Errorf("failed to read DNS resp payload length: %w", err) + } + buf := pool.Get(int(binary.BigEndian.Uint16(bLen))) + defer pool.Put(buf) + if _, err = io.ReadFull(conn, buf); err != nil { + return fmt.Errorf("failed to read DNS resp payload: %w", err) + } + if err = udpHandler(buf, destToSend); err != nil { + return fmt.Errorf("failed to write DNS resp to client: %w", err) + } + } + + // Print log. + if isNew || isDns { // Only print routing for new connection to avoid the log exploded (Quic and BT). if isDns && c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 { q := dnsMessage.Questions[0] c.log.WithFields(logrus.Fields{ "network": string(l4proto) + string(ipversion) + "(DNS)", "outbound": outbound.Name, - "dialer": d.Name(), + "dialer": realDialer.Name(), "qname": strings.ToLower(q.Name.String()), "qtype": q.Type, }).Infof("%v <-> %v", @@ -235,16 +344,12 @@ getNew: c.log.WithFields(logrus.Fields{ "network": string(l4proto) + string(ipversion), "outbound": outbound.Name, - "dialer": d.Name(), + "dialer": realDialer.Name(), }).Infof("%v <-> %v", RefineSourceToShow(src, destToSend.Addr()), RefineAddrPortToShow(destToSend), ) } } - //log.Printf("WriteToUDPAddrPort->%v", destToSend) - _, err = ue.WriteToUDPAddrPort(data, destToSend) - if err != nil { - return fmt.Errorf("failed to write UDP packet req: %w", err) - } + return nil } diff --git a/example.dae b/example.dae index 9b2b0aa..b295caa 100644 --- a/example.dae +++ b/example.dae @@ -11,11 +11,14 @@ global { udp_check_dns: 'cloudflare-dns.com:53' check_interval: 30s - # Now only support udp://IP:Port. Empty value '' indicates as-is. + # Value can be scheme://host:port or empty string ''. + # The scheme can be tcp/udp/tcp+udp. Empty string '' indicates as-is. + # If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose + # IPv4 or IPv6 to use according to group policy (such as min latency policy). # Please make sure DNS traffic will go through and be forwarded by dae. # The upstream DNS answer MUST NOT be polluted. # The request to dns upstream follows routing defined below. - dns_upstream: 'udp://8.8.8.8:53' + dns_upstream: 'tcp+udp://dns.google:53' # The LAN interface to bind. Use it if you only want to proxy LAN instead of localhost. # Multiple interfaces split by ",". @@ -72,7 +75,7 @@ routing { # Write your rules below. # dae arms DNS rush-answer filter so we can use 8.8.8.8 regardless of DNS pollution. - ip(8.8.8.8) && port(53) -> direct + domain(full:dns.google) && port(53) -> direct pname(firefox) && domain(ip.sb) -> direct pname(curl) && domain(ip.sb) -> my_group