feat: support tcp:// and tcp+udp:// for dns_upstream (#11)

This commit is contained in:
mzz
2023-02-09 11:40:34 +08:00
committed by GitHub
parent ac8b88d8ca
commit 15faa3cdd2
15 changed files with 697 additions and 185 deletions

View File

@ -13,8 +13,29 @@ import (
"golang.org/x/net/proxy"
"net/netip"
"strings"
"sync"
"time"
)
var (
systemDnsMu sync.Mutex
systemDns netip.AddrPort
)
func SystemDns() (dns netip.AddrPort, err error) {
systemDnsMu.Lock()
defer systemDnsMu.Unlock()
if !systemDns.IsValid() {
dnsConf := dnsReadConfig("/etc/resolv.conf")
if len(dnsConf.servers) == 0 {
err = fmt.Errorf("no valid dns server in /etc/resolv.conf")
return netip.AddrPort{}, err
}
systemDns = netip.MustParseAddrPort(dnsConf.servers[0])
}
return systemDns, nil
}
func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type) (addrs []netip.Addr, err error) {
if addr, err := netip.ParseAddr(host); err == nil {
if (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA {
@ -61,7 +82,23 @@ func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host
if err != nil {
return nil, err
}
ch := make(chan error, 1)
ch := make(chan error, 2)
go func() {
// Resend every 3 seconds.
for {
select {
case <-ctx.Done():
return
default:
time.Sleep(3 * time.Second)
}
_, err := c.Write(b)
if err != nil {
ch <- err
return
}
}
}()
go func() {
buf := pool.Get(512)
n, err := c.Read(buf)

View File

@ -0,0 +1,196 @@
// Modified from go1.18/src/net/dnsconfig_unix.go
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris
// Read system DNS config from /etc/resolv.conf
package netutils
import (
"bufio"
"net"
"net/netip"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
)
var (
defaultNS = []string{"127.0.0.1:53", "[::1]:53"}
getHostname = os.Hostname // variable for testing
)
type dnsConfig struct {
servers []string // server addresses (in host:port form) to use
search []string // rooted suffixes to append to local name
ndots int // number of dots in name to trigger absolute lookup
timeout time.Duration // wait before giving up on a query, including retries
attempts int // lost packets before giving up on server
rotate bool // round robin among servers
unknownOpt bool // anything unknown was encountered
lookup []string // OpenBSD top-level database "lookup" order
err error // any error that occurs during open of resolv.conf
mtime time.Time // time of resolv.conf modification
soffset uint32 // used by serverOffset
singleRequest bool // use sequential A and AAAA queries instead of parallel queries
useTCP bool // force usage of TCP for DNS resolutions
}
// See resolv.conf(5) on a Linux machine.
func dnsReadConfig(filename string) *dnsConfig {
conf := &dnsConfig{
ndots: 1,
timeout: 5 * time.Second,
attempts: 2,
}
file, err := os.Open(filename)
if err != nil {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
defer file.Close()
if fi, err := file.Stat(); err == nil {
conf.mtime = fi.ModTime()
} else {
conf.servers = defaultNS
conf.search = dnsDefaultSearch()
conf.err = err
return conf
}
fio := bufio.NewReader(file)
for line, _, err := fio.ReadLine(); err == nil; line, _, err = fio.ReadLine() {
if len(line) > 0 && (line[0] == ';' || line[0] == '#') {
// comment.
continue
}
f := strings.Fields(string(line))
if len(f) < 1 {
continue
}
switch f[0] {
case "nameserver": // add one name server
if len(f) > 1 && len(conf.servers) < 3 { // small, but the standard limit
// One more check: make sure server name is
// just an IP address. Otherwise we need DNS
// to look it up.
if _, e := netip.ParseAddr(f[1]); e == nil {
conf.servers = append(conf.servers, net.JoinHostPort(f[1], "53"))
}
}
case "domain": // set search path to just this domain
if len(f) > 1 {
conf.search = []string{ensureRooted(f[1])}
}
case "search": // set search path to given servers
conf.search = make([]string, len(f)-1)
for i := 0; i < len(conf.search); i++ {
conf.search[i] = ensureRooted(f[i+1])
}
case "options": // magic options
for _, s := range f[1:] {
switch {
case hasPrefix(s, "ndots:"):
n, _ := strconv.Atoi(s[6:])
if n < 0 {
n = 0
} else if n > 15 {
n = 15
}
conf.ndots = n
case hasPrefix(s, "timeout:"):
n, _ := strconv.Atoi(s[8:])
if n < 1 {
n = 1
}
conf.timeout = time.Duration(n) * time.Second
case hasPrefix(s, "attempts:"):
n, _ := strconv.Atoi(s[9:])
if n < 1 {
n = 1
}
conf.attempts = n
case s == "rotate":
conf.rotate = true
case s == "single-request" || s == "single-request-reopen":
// Linux option:
// http://man7.org/linux/man-pages/man5/resolv.conf.5.html
// "By default, glibc performs IPv4 and IPv6 lookups in parallel [...]
// This option disables the behavior and makes glibc
// perform the IPv6 and IPv4 requests sequentially."
conf.singleRequest = true
case s == "use-vc" || s == "usevc" || s == "tcp":
// Linux (use-vc), FreeBSD (usevc) and OpenBSD (tcp) option:
// http://man7.org/linux/man-pages/man5/resolv.conf.5.html
// "Sets RES_USEVC in _res.options.
// This option forces the use of TCP for DNS resolutions."
// https://www.freebsd.org/cgi/man.cgi?query=resolv.conf&sektion=5&manpath=freebsd-release-ports
// https://man.openbsd.org/resolv.conf.5
conf.useTCP = true
default:
conf.unknownOpt = true
}
}
case "lookup":
// OpenBSD option:
// https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
// "the legal space-separated values are: bind, file, yp"
conf.lookup = f[1:]
default:
conf.unknownOpt = true
}
}
if len(conf.servers) == 0 {
conf.servers = defaultNS
}
if len(conf.search) == 0 {
conf.search = dnsDefaultSearch()
}
return conf
}
// serverOffset returns an offset that can be used to determine
// indices of servers in c.servers when making queries.
// When the rotate option is enabled, this offset increases.
// Otherwise it is always 0.
func (c *dnsConfig) serverOffset() uint32 {
if c.rotate {
return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
}
return 0
}
func dnsDefaultSearch() []string {
hn, err := getHostname()
if err != nil {
// best effort
return nil
}
if i := strings.IndexByte(hn, '.'); i >= 0 && i < len(hn)-1 {
return []string{ensureRooted(hn[i+1:])}
}
return nil
}
func hasPrefix(s, prefix string) bool {
return len(s) >= len(prefix) && s[:len(prefix)] == prefix
}
func ensureRooted(s string) string {
if len(s) > 0 && s[len(s)-1] == '.' {
return s
}
return s + "."
}

48
common/netutils/ip46.go Normal file
View File

@ -0,0 +1,48 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package netutils
import (
"context"
"fmt"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/net/proxy"
"net/netip"
)
type Ip46 struct {
Ip4 netip.Addr
Ip6 netip.Addr
}
func ParseIp46(ctx context.Context, dialer proxy.Dialer, dns netip.AddrPort, host string, must46 bool) (ipv46 *Ip46, err error) {
addrs4, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeA)
if err != nil {
return nil, err
}
if len(addrs4) == 0 && must46 {
if must46 {
return nil, fmt.Errorf("domain \"%v\" has no ipv4 record", host)
} else {
addrs4 = []netip.Addr{{}}
}
}
addrs6, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeAAAA)
if err != nil {
return nil, err
}
if len(addrs6) == 0 {
if must46 {
return nil, fmt.Errorf("domain \"%v\" has no ipv6 record", host)
} else {
addrs6 = []netip.Addr{{}}
}
}
return &Ip46{
Ip4: addrs4[0],
Ip6: addrs6[0],
}, nil
}

View File

@ -83,8 +83,8 @@ func (a *AliveDialerSet) GetRand() *Dialer {
}
// GetMinLatency acquires correct selectionPolicy.
func (a *AliveDialerSet) GetMinLatency() *Dialer {
return a.minLatency.dialer
func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) {
return a.minLatency.dialer, a.minLatency.latency
}
// NotifyLatencyChange should be invoked when dialer every time latency and alive state changes.

View File

@ -24,10 +24,6 @@ import (
"time"
)
var (
BootstrapDns = netip.MustParseAddrPort("223.5.5.5:53")
)
type collection struct {
// AliveDialerSetSet uses reference counting.
AliveDialerSetSet AliveDialerSetSet
@ -71,43 +67,22 @@ func (d *Dialer) MustGetAlive(l4proto consts.L4ProtoStr, ipversion consts.IpVers
return d.mustGetCollection(l4proto, ipversion).Alive
}
type Ip46 struct {
Ip4 netip.Addr
Ip6 netip.Addr
}
func ParseIp46(ctx context.Context, host string) (ipv46 *Ip46, err error) {
addrs4, err := netutils.ResolveNetip(ctx, SymmetricDirect, BootstrapDns, host, dnsmessage.TypeA)
if err != nil {
return nil, err
}
if len(addrs4) == 0 {
return nil, fmt.Errorf("domain \"%v\" has no ipv4 record", host)
}
addrs6, err := netutils.ResolveNetip(ctx, SymmetricDirect, BootstrapDns, host, dnsmessage.TypeAAAA)
if err != nil {
return nil, err
}
if len(addrs6) == 0 {
return nil, fmt.Errorf("domain \"%v\" has no ipv6 record", host)
}
return &Ip46{
Ip4: addrs4[0],
Ip6: addrs6[0],
}, nil
}
type TcpCheckOption struct {
Url *netutils.URL
*Ip46
*netutils.Ip46
}
func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOption, err error) {
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
}
u, err := url.Parse(rawURL)
if err != nil {
return nil, err
}
ip46, err := ParseIp46(ctx, u.Hostname())
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, u.Hostname(), true)
if err != nil {
return nil, err
}
@ -120,10 +95,15 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
type UdpCheckOption struct {
DnsHost string
DnsPort uint16
*Ip46
*netutils.Ip46
}
func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheckOption, err error) {
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
}
host, _port, err := net.SplitHostPort(dnsHostPort)
if err != nil {
return nil, err
@ -132,7 +112,7 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck
if err != nil {
return nil, fmt.Errorf("bad port: %v", err)
}
ip46, err := ParseIp46(ctx, host)
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, host, true)
if err != nil {
return nil, err
}

View File

@ -15,6 +15,7 @@ import (
"net"
"net/netip"
"strings"
"time"
)
type DialerGroup struct {
@ -95,9 +96,9 @@ func (g *DialerGroup) SetSelectionPolicy(policy DialerSelectionPolicy) {
}
// Select selects a dialer from group according to selectionPolicy.
func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) (*dialer.Dialer, error) {
func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) (d *dialer.Dialer, latency time.Duration, err error) {
if len(g.Dialers) == 0 {
return nil, fmt.Errorf("no dialer in this group")
return nil, 0, fmt.Errorf("no dialer in this group")
}
var a *dialer.AliveDialerSet
switch l4proto {
@ -116,7 +117,7 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi
a = g.AliveUdp6DialerSet
}
default:
return nil, fmt.Errorf("DialerGroup.Select: unexpected l4proto type: %v", l4proto)
return nil, 0, fmt.Errorf("DialerGroup.Select: unexpected l4proto type: %v", l4proto)
}
switch g.selectionPolicy.Policy {
@ -128,30 +129,30 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi
"network": string(l4proto) + string(ipversion),
"group": g.Name,
}).Warnf("No alive dialer in DialerGroup, use \"block\".")
return g.block, nil
return g.block, 0, nil
}
return d, nil
return d, 0, nil
case consts.DialerSelectionPolicy_Fixed:
if g.selectionPolicy.FixedIndex < 0 || g.selectionPolicy.FixedIndex >= len(g.Dialers) {
return nil, fmt.Errorf("selected dialer index is out of range")
return nil, 0, fmt.Errorf("selected dialer index is out of range")
}
return g.Dialers[g.selectionPolicy.FixedIndex], nil
return g.Dialers[g.selectionPolicy.FixedIndex], 0, nil
case consts.DialerSelectionPolicy_MinLastLatency, consts.DialerSelectionPolicy_MinAverage10Latencies:
d := a.GetMinLatency()
d, latency := a.GetMinLatency()
if d == nil {
// No alive dialer.
g.log.WithFields(logrus.Fields{
"network": string(l4proto) + string(ipversion),
"group": g.Name,
}).Warnf("No alive dialer in DialerGroup, use \"block\".")
return g.block, nil
return g.block, 0, nil
}
return d, nil
return d, latency, nil
default:
return nil, fmt.Errorf("unsupported DialerSelectionPolicy: %v", g.selectionPolicy)
return nil, 0, fmt.Errorf("unsupported DialerSelectionPolicy: %v", g.selectionPolicy)
}
}
@ -164,9 +165,9 @@ func (g *DialerGroup) Dial(network string, addr string) (c net.Conn, err error)
ipversion := consts.IpVersionFromAddr(ipAddr)
switch {
case strings.HasPrefix(network, "tcp"):
d, err = g.Select(consts.L4ProtoStr_TCP, ipversion)
d, _, err = g.Select(consts.L4ProtoStr_TCP, ipversion)
case strings.HasPrefix(network, "udp"):
d, err = g.Select(consts.L4ProtoStr_UDP, ipversion)
d, _, err = g.Select(consts.L4ProtoStr_UDP, ipversion)
default:
return nil, fmt.Errorf("unexpected network: %v", network)
}

View File

@ -44,9 +44,9 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: fixedIndex,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
for i := 0; i < 10; i++ {
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
fixedIndex = 0
g.selectionPolicy.FixedIndex = fixedIndex
for i := 0; i < 10; i++ {
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -98,7 +98,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_MinLastLatency,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
// Test 1000 times.
for i := 0; i < 1000; i++ {
@ -127,7 +127,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
}
g.AliveTcp4DialerSet.NotifyLatencyChange(d, alive)
}
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -170,10 +170,10 @@ func TestDialerGroup_Select_Random(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -217,12 +217,12 @@ func TestDialerGroup_SetAlive(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
zeroTarget := 3
g.AliveTcp4DialerSet.NotifyLatencyChange(dialers[zeroTarget], false)
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, err := g.Select(consts.L4ProtoStr_UDP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_UDP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}

View File

@ -20,6 +20,7 @@ import (
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix"
"net"
"net/netip"
@ -50,7 +51,7 @@ type ControlPlane struct {
// mutex protects the dnsCache.
mutex sync.Mutex
dnsCache map[string]*dnsCache
dnsUpstream netip.AddrPort
dnsUpstream *DnsUpstraem
}
func NewControlPlane(
@ -189,7 +190,7 @@ func NewControlPlane(
}
/// DialerGroups (outbounds).
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.TODO(), 30*time.Second)
defer cancel()
tcpCheckOption, err := dialer.ParseTcpCheckOption(ctx, global.TcpCheckUrl)
if err != nil {
@ -285,22 +286,78 @@ func NewControlPlane(
}
/// DNS upstream.
var dnsAddrPort netip.AddrPort
var dnsUpstream *DnsUpstraem
if !global.DnsUpstream.Empty {
if dnsAddrPort, err = resolveDnsUpstream(global.DnsUpstream.Url); err != nil {
if dnsUpstream, err = ResolveDnsUpstream(ctx, global.DnsUpstream.Url); err != nil {
return nil, err
}
dnsAddr16 := dnsAddrPort.Addr().As16()
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{
Ip: common.Ipv6ByteSliceToUint32Array(dnsAddr16[:]),
Port: internal.Htons(dnsAddrPort.Port()),
ip4in6 := dnsUpstream.Ip4.As16()
ip6 := dnsUpstream.Ip6.As16()
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]),
Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]),
HasIp4: dnsUpstream.Ip4.IsValid(),
HasIp6: dnsUpstream.Ip6.IsValid(),
Port: internal.Htons(dnsUpstream.Port),
}, ebpf.UpdateAny); err != nil {
return nil, err
}
defer func() {
// Update dns cache to support domain routing for hostname of dns_upstream.
if err == nil {
// Ten years later.
deadline := time.Now().Add(24 * time.Hour * 365 * 10)
fqdn := dnsUpstream.Hostname
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
}
if dnsUpstream.Ip4.IsValid() {
typ := dnsmessage.TypeA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AResource{
A: dnsUpstream.Ip4.As4(),
},
}}
if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil {
c = nil
return
}
}
if dnsUpstream.Ip6.IsValid() {
typ := dnsmessage.TypeAAAA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AAAAResource{
AAAA: dnsUpstream.Ip6.As16(),
},
}}
if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil {
c = nil
return
}
}
}
}()
} else {
// Empty.
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{
Ip: [4]uint32{},
// Empty string. As-is.
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: [4]uint32{},
Ip6: [4]uint32{},
HasIp4: false,
HasIp6: false,
// Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array.
Port: 0,
}, ebpf.UpdateAny); err != nil {
@ -325,7 +382,7 @@ func NewControlPlane(
Final: routingA.Final,
mutex: sync.Mutex{},
dnsCache: make(map[string]*dnsCache),
dnsUpstream: dnsAddrPort,
dnsUpstream: dnsUpstream,
}, nil
}

View File

@ -1,34 +0,0 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package control
import (
"fmt"
"net"
"net/netip"
"net/url"
)
func resolveDnsUpstream(dnsUpstream *url.URL) (addrPort netip.AddrPort, err error) {
if dnsUpstream.Scheme != "udp" {
return netip.AddrPort{}, fmt.Errorf("dns_upstream now only supports udp://")
}
port := dnsUpstream.Port()
if port == "" {
port = "53"
}
hostname := dnsUpstream.Hostname()
ips, _ := net.LookupIP(hostname)
if len(ips) == 0 {
return netip.AddrPort{}, fmt.Errorf("cannot resolve hostname of dns upstream: %v", hostname)
}
// resolve hostname
dnsAddrPort, err := netip.ParseAddrPort(net.JoinHostPort(ips[0].String(), port))
if err != nil {
return netip.AddrPort{}, fmt.Errorf("failed to parse DNS upstream: \"%v\": %w", dnsUpstream.String(), err)
}
return dnsAddrPort, nil
}

View File

@ -282,27 +282,36 @@ loop:
"auth": FormatDnsRsc(msg.Authorities),
"addi": FormatDnsRsc(msg.Additionals),
}).Tracef("Update DNS record cache")
if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil {
return nil, err
}
// Pack to get newData.
return msg.Pack()
}
func (c *ControlPlane) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) {
c.mutex.Lock()
fqdn := strings.ToLower(q.Name.String())
cacheKey := fqdn + q.Type.String()
fqdn := strings.ToLower(host)
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
cacheKey := fqdn + typ.String()
cache, ok := c.dnsCache[cacheKey]
if ok {
c.mutex.Unlock()
cache.Deadline = time.Now().Add(time.Duration(ttl)*time.Second + DnsNatTimeout)
cache.Answers = msg.Answers
cache.Deadline = deadline
cache.Answers = answers
} else {
cache = &dnsCache{
DomainBitmap: c.MatchDomainBitmap(strings.TrimSuffix(fqdn, ".")),
Answers: msg.Answers,
Deadline: time.Now().Add(time.Duration(ttl)*time.Second + DnsNatTimeout),
Answers: answers,
Deadline: deadline,
}
c.dnsCache[cacheKey] = cache
c.mutex.Unlock()
}
if err = c.BatchUpdateDomainRouting(cache); err != nil {
return nil, fmt.Errorf("BatchUpdateDomainRouting: %w", err)
return fmt.Errorf("BatchUpdateDomainRouting: %w", err)
}
// Pack to get newData.
return msg.Pack()
return nil
}

89
control/dns_upstream.go Normal file
View File

@ -0,0 +1,89 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package control
import (
"context"
"fmt"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/common/netutils"
"github.com/v2rayA/dae/component/outbound/dialer"
"net/url"
"strconv"
)
type DnsUpstreamScheme string
const (
DnsUpstreamScheme_TCP DnsUpstreamScheme = "tcp"
DnsUpstreamScheme_UDP DnsUpstreamScheme = "udp"
DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp"
)
type DnsUpstraem struct {
Scheme DnsUpstreamScheme
Hostname string
Port uint16
*netutils.Ip46
}
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstraem, err error) {
var _port string
switch DnsUpstreamScheme(dnsUpstream.Scheme) {
case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP:
_port = dnsUpstream.Port()
if _port == "" {
_port = "53"
}
default:
return nil, fmt.Errorf("dns_upstream now only supports auto://, udp://, tcp:// and empty string (as-is)")
}
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16)
if err != nil {
return nil, fmt.Errorf("parse dns_upstream port: %v", err)
}
hostname := dnsUpstream.Hostname()
ip46, err := netutils.ParseIp46(ctx, dialer.SymmetricDirect, systemDns, hostname, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream")
}
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
return nil, fmt.Errorf("dns_upstream has no record")
}
return &DnsUpstraem{
Scheme: DnsUpstreamScheme(dnsUpstream.Scheme),
Hostname: hostname,
Port: uint16(port),
Ip46: ip46,
}, nil
}
func (u *DnsUpstraem) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
if u.Ip4.IsValid() && u.Ip6.IsValid() {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6}
} else {
if u.Ip4.IsValid() {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4}
} else {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_6}
}
}
switch u.Scheme {
case DnsUpstreamScheme_TCP:
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP}
case DnsUpstreamScheme_UDP:
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP}
case DnsUpstreamScheme_TCP_UDP:
// UDP first.
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP, consts.L4ProtoStr_TCP}
}
return ipversions, l4protos
}

View File

@ -74,6 +74,16 @@ enum {
DisableL4ChecksumPolicy_SetZero,
};
// Param keys:
static const __u32 zero_key = 0;
static const __u32 tproxy_port_key = 1;
static const __u32 one_key = 1;
static const __u32 disable_l4_tx_checksum_key
__attribute__((unused, deprecated)) = 2;
static const __u32 disable_l4_rx_checksum_key
__attribute__((unused, deprecated)) = 3;
static const __u32 control_plane_pid_key = 4;
// Outbound Connectivity Map:
struct outbound_connectivity_query {
@ -97,15 +107,8 @@ struct {
__uint(max_entries, 2);
} listen_socket_map SEC(".maps");
// Param keys:
static const __u32 zero_key = 0;
static const __u32 tproxy_port_key = 1;
static const __u32 one_key = 1;
static const __u32 disable_l4_tx_checksum_key
__attribute__((unused, deprecated)) = 2;
static const __u32 disable_l4_rx_checksum_key
__attribute__((unused, deprecated)) = 3;
static const __u32 control_plane_pid_key = 4;
/// TODO: Remove items from the dst_map by conntrack.
// Dest map:
struct ip_port {
__be32 ip[4];
@ -125,8 +128,6 @@ struct tuples {
__u8 l4proto;
};
/// TODO: Remove items from the dst_map by conntrack.
// Dest map:
struct {
__uint(type, BPF_MAP_TYPE_LRU_HASH);
__type(key,
@ -195,10 +196,19 @@ struct {
} ipproto_hdrsize_map SEC(".maps");
// Dns upstream:
struct dns_upstream {
__be32 ip4[4];
__be32 ip6[4];
bool hasIp4;
bool hasIp6;
__be16 port;
};
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, struct ip_port);
__type(value, struct dns_upstream);
/// FIXME: l4proto is always udp.
__uint(max_entries, 1);
} dns_upstream_map SEC(".maps");
@ -974,11 +984,22 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
// Modify DNS upstream for routing.
if (h_dport == 53 && _l4proto_type == L4ProtoType_UDP) {
struct ip_port *upstream =
struct dns_upstream *upstream =
bpf_map_lookup_elem(&dns_upstream_map, &zero_key);
if (upstream && upstream->port != 0) {
h_dport = bpf_ntohs(upstream->port);
__builtin_memcpy(daddr, upstream->ip, IPV6_BYTE_LENGTH);
if (_ipversion_type == IpVersionType_4 && upstream->hasIp4) {
__builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH);
} else if (_ipversion_type == IpVersionType_6 && upstream->hasIp6) {
__builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH);
} else if (upstream->hasIp4) {
__builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH);
} else if (upstream->hasIp6) {
__builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH);
} else {
bpf_printk("bad dns upstream; use as-is.");
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
}
} else {
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
}

View File

@ -60,7 +60,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
}
l4proto := consts.L4ProtoStr_TCP
ipversion := consts.IpVersionFromAddr(dst.Addr())
dialer, err := outbound.Select(l4proto, ipversion)
dialer, _, err := outbound.Select(l4proto, ipversion)
if err != nil {
return fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
"golang.org/x/net/dns/dnsmessage"
"io"
"net"
"net/netip"
"strings"
@ -94,11 +95,12 @@ func sendPktBind(data []byte, from netip.AddrPort, to netip.AddrPort) error {
return err
}
func (c *ControlPlane) RelayToUDP(to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAns bool) UdpHandler {
func (c *ControlPlane) WriteToUDP(to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAnsFunc func(from netip.AddrPort) bool) UdpHandler {
return func(data []byte, from netip.AddrPort) (err error) {
// Do not return conn-unrelated err in this func.
if isDNS {
validateRushAns := validateRushAnsFunc(from)
data, err = c.DnsRespHandler(data, validateRushAns)
if err != nil {
if validateRushAns && errors.Is(err, SuspectedRushAnswerError) {
@ -158,15 +160,6 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
return nil
}
// Need to make a DNS request.
if c.dnsUpstream.IsValid() {
c.log.Tracef("Modify dns target %v to upstream: %v", RefineAddrPortToShow(destToSend), c.dnsUpstream)
// Modify dns target to upstream.
// NOTICE: Routing was calculated in advance by the eBPF program.
dummyFrom = &dst
destToSend = c.dnsUpstream
}
// Flip dns question to reduce dns pollution.
FlipDnsQuestionCase(dnsMessage)
// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
@ -180,51 +173,167 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
}
}
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
// Because additional record OPT may not be supported by home router.
// So se should trust home devices even if they make rush-answer (or looks like).
validateRushAns := outboundIndex == consts.OutboundDirect && !destToSend.Addr().IsPrivate()
// Get udp endpoint.
l4proto := consts.L4ProtoStr_UDP
ipversion := consts.IpVersionFromAddr(dst.Addr())
getNew:
ue, isNew, err := DefaultUdpEndpointPool.GetOrCreate(src, &UdpEndpointOptions{
Handler: c.RelayToUDP(src, isDns, dummyFrom, validateRushAns),
NatTimeout: natTimeout,
DialerFunc: func() (*dialer.Dialer, error) {
newDialer, err := outbound.Select(l4proto, ipversion)
if err != nil {
return nil, fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err)
}
return newDialer, nil
},
Target: destToSend,
})
if err != nil {
return fmt.Errorf("failed to GetOrCreate: %w", err)
}
// If the udp endpoint has been not alive, remove it from pool and get a new one.
if !isNew && !ue.Dialer.MustGetAlive(l4proto, ipversion) {
c.log.WithFields(logrus.Fields{
"src": src.String(),
"network": string(l4proto) + string(ipversion),
"dialer": ue.Dialer.Name(),
}).Debugln("Old udp endpoint is not alive and removed")
_ = DefaultUdpEndpointPool.Remove(src, ue)
goto getNew
}
// This is real dialer.
d := ue.Dialer
var dialerForNew *dialer.Dialer
if isNew {
// For DNS request, modify dst to dns upstream.
// NOTICE: We might modify l4proto and ipversion.
if isDns && c.dnsUpstream != nil {
// Modify dns target to upstream.
// NOTICE: Routing was calculated in advance by the eBPF program.
/// Choose the best l4proto and ipversion.
// Get available ipversions and l4protos for DNS upstream.
ipversions, l4protos := c.dnsUpstream.SupportedNetworks()
var (
bestDialer *dialer.Dialer
bestLatency time.Duration
bestTarget netip.AddrPort
)
c.log.WithFields(logrus.Fields{
"ipversions": ipversions,
"l4protos": l4protos,
}).Debugln("Choose DNS path")
// Get the min latency path.
for _, ver := range ipversions {
for _, proto := range l4protos {
d, latency, err := outbound.Select(proto, ver)
if err != nil {
continue
}
c.log.WithFields(logrus.Fields{
"latency": latency,
"ver": ver,
"proto": proto,
"outbound": outbound.Name,
}).Debugln("Choose")
if bestDialer == nil || latency < bestLatency {
bestDialer = d
bestLatency = latency
l4proto = proto
ipversion = ver
}
}
}
switch ipversion {
case consts.IpVersionStr_4:
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip4, c.dnsUpstream.Port)
case consts.IpVersionStr_6:
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip6, c.dnsUpstream.Port)
}
dialerForNew = bestDialer
dummyFrom = &dst
destToSend = bestTarget
c.log.WithFields(logrus.Fields{
"Original": RefineAddrPortToShow(dst),
"New": destToSend,
"Network": string(l4proto) + string(ipversion),
}).Traceln("Modify DNS target")
}
if dialerForNew == nil {
dialerForNew, _, err = outbound.Select(l4proto, ipversion)
if err != nil {
return fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err)
}
}
var isNew bool
var realDialer *dialer.Dialer
udpHandler := c.WriteToUDP(src, isDns, dummyFrom, func(from netip.AddrPort) bool {
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
// Because additional record OPT may not be supported by home router.
// So se should trust home devices even if they make rush-answer (or looks like).
return outboundIndex == consts.OutboundDirect && !from.Addr().IsPrivate()
})
// Dial and send.
switch l4proto {
case consts.L4ProtoStr_UDP:
// Get udp endpoint.
var ue *UdpEndpoint
getNew:
ue, isNew, err = DefaultUdpEndpointPool.GetOrCreate(src, &UdpEndpointOptions{
Handler: udpHandler,
NatTimeout: natTimeout,
DialerFunc: func() (*dialer.Dialer, error) {
return dialerForNew, nil
},
Target: destToSend,
})
if err != nil {
return fmt.Errorf("failed to GetOrCreate: %w", err)
}
// If the udp endpoint has been not alive, remove it from pool and get a new one.
if !isNew && !ue.Dialer.MustGetAlive(l4proto, ipversion) {
c.log.WithFields(logrus.Fields{
"src": src.String(),
"network": string(l4proto) + string(ipversion),
"dialer": ue.Dialer.Name(),
}).Debugln("Old udp endpoint is not alive and removed")
_ = DefaultUdpEndpointPool.Remove(src, ue)
goto getNew
}
// This is real dialer.
realDialer = ue.Dialer
//log.Printf("WriteToUDPAddrPort->%v", destToSend)
_, err = ue.WriteToUDPAddrPort(data, destToSend)
if err != nil {
return fmt.Errorf("failed to write UDP packet req: %w", err)
}
case consts.L4ProtoStr_TCP:
// MUST be DNS.
if !isDns {
return fmt.Errorf("UDP to TCP only support DNS request")
}
realDialer = dialerForNew
// We can block because we are in a coroutine.
conn, err := dialerForNew.Dial("tcp", destToSend.String())
if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(natTimeout))
// We should write two byte length in the front of TCP DNS request.
bLen := pool.Get(2)
defer pool.Put(bLen)
binary.BigEndian.PutUint16(bLen, uint16(len(data)))
_, err = conn.Write(bLen)
if err != nil {
return fmt.Errorf("failed to write DNS req length: %w", err)
}
if _, err = conn.Write(data); err != nil {
return fmt.Errorf("failed to write DNS req payload: %w", err)
}
// Read two byte length.
if _, err = io.ReadFull(conn, bLen); err != nil {
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
}
buf := pool.Get(int(binary.BigEndian.Uint16(bLen)))
defer pool.Put(buf)
if _, err = io.ReadFull(conn, buf); err != nil {
return fmt.Errorf("failed to read DNS resp payload: %w", err)
}
if err = udpHandler(buf, destToSend); err != nil {
return fmt.Errorf("failed to write DNS resp to client: %w", err)
}
}
// Print log.
if isNew || isDns {
// Only print routing for new connection to avoid the log exploded (Quic and BT).
if isDns && c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
c.log.WithFields(logrus.Fields{
"network": string(l4proto) + string(ipversion) + "(DNS)",
"outbound": outbound.Name,
"dialer": d.Name(),
"dialer": realDialer.Name(),
"qname": strings.ToLower(q.Name.String()),
"qtype": q.Type,
}).Infof("%v <-> %v",
@ -235,16 +344,12 @@ getNew:
c.log.WithFields(logrus.Fields{
"network": string(l4proto) + string(ipversion),
"outbound": outbound.Name,
"dialer": d.Name(),
"dialer": realDialer.Name(),
}).Infof("%v <-> %v",
RefineSourceToShow(src, destToSend.Addr()), RefineAddrPortToShow(destToSend),
)
}
}
//log.Printf("WriteToUDPAddrPort->%v", destToSend)
_, err = ue.WriteToUDPAddrPort(data, destToSend)
if err != nil {
return fmt.Errorf("failed to write UDP packet req: %w", err)
}
return nil
}

View File

@ -11,11 +11,14 @@ global {
udp_check_dns: 'cloudflare-dns.com:53'
check_interval: 30s
# Now only support udp://IP:Port. Empty value '' indicates as-is.
# Value can be scheme://host:port or empty string ''.
# The scheme can be tcp/udp/tcp+udp. Empty string '' indicates as-is.
# If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose
# IPv4 or IPv6 to use according to group policy (such as min latency policy).
# Please make sure DNS traffic will go through and be forwarded by dae.
# The upstream DNS answer MUST NOT be polluted.
# The request to dns upstream follows routing defined below.
dns_upstream: 'udp://8.8.8.8:53'
dns_upstream: 'tcp+udp://dns.google:53'
# The LAN interface to bind. Use it if you only want to proxy LAN instead of localhost.
# Multiple interfaces split by ",".
@ -72,7 +75,7 @@ routing {
# Write your rules below.
# dae arms DNS rush-answer filter so we can use 8.8.8.8 regardless of DNS pollution.
ip(8.8.8.8) && port(53) -> direct
domain(full:dns.google) && port(53) -> direct
pname(firefox) && domain(ip.sb) -> direct
pname(curl) && domain(ip.sb) -> my_group