feat: lazily init dns upstream and tcp/udp check to avoid fatal when start

This commit is contained in:
mzz2017
2023-02-09 19:22:37 +08:00
parent f0f68ffb84
commit 3060417be7
6 changed files with 346 additions and 141 deletions

View File

@ -22,18 +22,48 @@ import (
var ( var (
systemDnsMu sync.Mutex systemDnsMu sync.Mutex
systemDns netip.AddrPort systemDns netip.AddrPort
systemDnsNextUpdateAfter time.Time
) )
func TryUpdateSystemDns() (err error) {
systemDnsMu.Lock()
err = tryUpdateSystemDns()
systemDnsMu.Unlock()
return err
}
// TryUpdateSystemDns1s will update system DNS if 1 second has elapsed since the last TryUpdateSystemDns1s call.
func TryUpdateSystemDns1s() (err error) {
systemDnsMu.Lock()
defer systemDnsMu.Unlock()
if time.Now().Before(systemDnsNextUpdateAfter) {
return fmt.Errorf("update too quickly")
}
err = tryUpdateSystemDns()
if err != nil {
return err
}
systemDnsNextUpdateAfter = time.Now().Add(time.Second)
return nil
}
func tryUpdateSystemDns() (err error) {
dnsConf := dnsReadConfig("/etc/resolv.conf")
if len(dnsConf.servers) == 0 {
err = fmt.Errorf("no valid dns server in /etc/resolv.conf")
return err
}
systemDns = netip.MustParseAddrPort(dnsConf.servers[0])
return nil
}
func SystemDns() (dns netip.AddrPort, err error) { func SystemDns() (dns netip.AddrPort, err error) {
systemDnsMu.Lock() systemDnsMu.Lock()
defer systemDnsMu.Unlock() defer systemDnsMu.Unlock()
if !systemDns.IsValid() { if !systemDns.IsValid() {
dnsConf := dnsReadConfig("/etc/resolv.conf") if err = tryUpdateSystemDns(); err != nil {
if len(dnsConf.servers) == 0 {
err = fmt.Errorf("no valid dns server in /etc/resolv.conf")
return netip.AddrPort{}, err return netip.AddrPort{}, err
} }
systemDns = netip.MustParseAddrPort(dnsConf.servers[0])
} }
return systemDns, nil return systemDns, nil
} }

View File

@ -21,6 +21,7 @@ import (
"path" "path"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
@ -77,6 +78,11 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
if err != nil {
_ = netutils.TryUpdateSystemDns1s()
}
}()
u, err := url.Parse(rawURL) u, err := url.Parse(rawURL)
if err != nil { if err != nil {
@ -103,6 +109,11 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
if err != nil {
_ = netutils.TryUpdateSystemDns1s()
}
}()
host, _port, err := net.SplitHostPort(dnsHostPort) host, _port, err := net.SplitHostPort(dnsHostPort)
if err != nil { if err != nil {
@ -123,6 +134,48 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck
}, nil }, nil
} }
type TcpCheckOptionRaw struct {
opt *TcpCheckOption
mu sync.Mutex
Raw string
}
func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
c.opt = tcpCheckOption
}
return c.opt, nil
}
type UdpCheckOptionRaw struct {
opt *UdpCheckOption
mu sync.Mutex
Raw string
}
func (c *UdpCheckOptionRaw) Option() (opt *UdpCheckOption, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
udpCheckOption, err := ParseUdpCheckOption(ctx, c.Raw)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
c.opt = udpCheckOption
}
return c.opt, nil
}
type CheckOption struct { type CheckOption struct {
L4proto consts.L4ProtoStr L4proto consts.L4ProtoStr
IpVersion consts.IpVersionStr IpVersion consts.IpVersionStr
@ -146,28 +199,44 @@ func (d *Dialer) aliveBackground() {
L4proto: consts.L4ProtoStr_TCP, L4proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4, IpVersion: consts.IpVersionStr_4,
CheckFunc: func(ctx context.Context) (ok bool, err error) { CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.HttpCheck(ctx, d.TcpCheckOption.Url, d.TcpCheckOption.Ip4) opt, err := d.TcpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.HttpCheck(ctx, opt.Url, opt.Ip4)
}, },
} }
tcp6CheckOpt := &CheckOption{ tcp6CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_TCP, L4proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6, IpVersion: consts.IpVersionStr_6,
CheckFunc: func(ctx context.Context) (ok bool, err error) { CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.HttpCheck(ctx, d.TcpCheckOption.Url, d.TcpCheckOption.Ip6) opt, err := d.TcpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.HttpCheck(ctx, opt.Url, opt.Ip6)
}, },
} }
udp4CheckOpt := &CheckOption{ udp4CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_UDP, L4proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_4, IpVersion: consts.IpVersionStr_4,
CheckFunc: func(ctx context.Context) (ok bool, err error) { CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.DnsCheck(ctx, netip.AddrPortFrom(d.UdpCheckOption.Ip4, d.UdpCheckOption.DnsPort)) opt, err := d.UdpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort))
}, },
} }
udp6CheckOpt := &CheckOption{ udp6CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_UDP, L4proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_6, IpVersion: consts.IpVersionStr_6,
CheckFunc: func(ctx context.Context) (ok bool, err error) { CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.DnsCheck(ctx, netip.AddrPortFrom(d.UdpCheckOption.Ip4, d.UdpCheckOption.DnsPort)) opt, err := d.UdpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort))
}, },
} }
// Check once immediately. // Check once immediately.
@ -176,25 +245,72 @@ func (d *Dialer) aliveBackground() {
go d.Check(timeout, tcp6CheckOpt) go d.Check(timeout, tcp6CheckOpt)
go d.Check(timeout, udp6CheckOpt) go d.Check(timeout, udp6CheckOpt)
ctx, cancel := context.WithCancel(d.ctx)
defer cancel()
go func() {
/// Splice ticker.C to checkCh.
// Sleep to avoid avalanche. // Sleep to avoid avalanche.
time.Sleep(time.Duration(fastrand.Int63n(int64(cycle)))) time.Sleep(time.Duration(fastrand.Int63n(int64(cycle))))
d.tickerMu.Lock() d.tickerMu.Lock()
d.ticker.Reset(cycle) d.ticker = time.NewTicker(cycle)
d.tickerMu.Unlock() d.tickerMu.Unlock()
for range d.ticker.C { for t := range d.ticker.C {
select {
case <-ctx.Done():
return
default:
d.checkCh <- t
}
}
}()
var wg sync.WaitGroup
for range d.checkCh {
// No need to test if there is no dialer selection policy using its latency. // No need to test if there is no dialer selection policy using its latency.
if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 { if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 {
go d.Check(timeout, tcp4CheckOpt) wg.Add(1)
go func() {
d.Check(timeout, tcp4CheckOpt)
wg.Done()
}()
} }
if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 { if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 {
go d.Check(timeout, tcp6CheckOpt) wg.Add(1)
go func() {
d.Check(timeout, tcp6CheckOpt)
wg.Done()
}()
} }
if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 { if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 {
go d.Check(timeout, udp4CheckOpt) wg.Add(1)
go func() {
d.Check(timeout, udp4CheckOpt)
wg.Done()
}()
} }
if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 { if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 {
go d.Check(timeout, udp6CheckOpt) wg.Add(1)
go func() {
d.Check(timeout, udp6CheckOpt)
wg.Done()
}()
} }
// Wait to block the loop.
wg.Wait()
}
}
// NotifyCheck will succeed only when CheckEnabled is true.
func (d *Dialer) NotifyCheck() {
select {
case <-d.ctx.Done():
return
default:
}
select {
// If fail to push elem to chan, the check is in process.
case d.checkCh <- time.Now():
default:
} }
} }

View File

@ -1,6 +1,7 @@
package dialer package dialer
import ( import (
"context"
"fmt" "fmt"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
@ -26,12 +27,15 @@ type Dialer struct {
tickerMu sync.Mutex tickerMu sync.Mutex
ticker *time.Ticker ticker *time.Ticker
checkCh chan time.Time
ctx context.Context
cancel context.CancelFunc
} }
type GlobalOption struct { type GlobalOption struct {
Log *logrus.Logger Log *logrus.Logger
TcpCheckOption *TcpCheckOption TcpCheckOptionRaw TcpCheckOptionRaw // Lazy parse
UdpCheckOption *UdpCheckOption UdpCheckOptionRaw UdpCheckOptionRaw // Lazy parse
CheckInterval time.Duration CheckInterval time.Duration
} }
@ -47,6 +51,7 @@ func NewDialer(dialer proxy.Dialer, option *GlobalOption, iOption InstanceOption
for i := range collections { for i := range collections {
collections[i] = newCollection() collections[i] = newCollection()
} }
ctx, cancel := context.WithCancel(context.Background())
d := &Dialer{ d := &Dialer{
GlobalOption: option, GlobalOption: option,
instanceOption: iOption, instanceOption: iOption,
@ -54,9 +59,13 @@ func NewDialer(dialer proxy.Dialer, option *GlobalOption, iOption InstanceOption
name: name, name: name,
protocol: protocol, protocol: protocol,
link: link, link: link,
collectionFineMu: sync.Mutex{},
collections: collections, collections: collections,
// Set a very big cycle to wait for init. tickerMu: sync.Mutex{},
ticker: time.NewTicker(time.Hour), ticker: nil,
checkCh: make(chan time.Time, 1),
ctx: ctx,
cancel: cancel,
} }
if iOption.CheckEnabled { if iOption.CheckEnabled {
go d.aliveBackground() go d.aliveBackground()
@ -65,9 +74,13 @@ func NewDialer(dialer proxy.Dialer, option *GlobalOption, iOption InstanceOption
} }
func (d *Dialer) Close() error { func (d *Dialer) Close() error {
d.cancel()
d.tickerMu.Lock() d.tickerMu.Lock()
if d.ticker != nil {
d.ticker.Stop() d.ticker.Stop()
}
d.tickerMu.Unlock() d.tickerMu.Unlock()
close(d.checkCh)
return nil return nil
} }

View File

@ -51,7 +51,7 @@ type ControlPlane struct {
// mutex protects the dnsCache. // mutex protects the dnsCache.
mutex sync.Mutex mutex sync.Mutex
dnsCache map[string]*dnsCache dnsCache map[string]*dnsCache
dnsUpstream *DnsUpstraem dnsUpstream DnsUpstreamRaw
} }
func NewControlPlane( func NewControlPlane(
@ -191,20 +191,10 @@ func NewControlPlane(
} }
/// DialerGroups (outbounds). /// DialerGroups (outbounds).
ctx, cancel := context.WithTimeout(context.TODO(), 30*time.Second)
defer cancel()
tcpCheckOption, err := dialer.ParseTcpCheckOption(ctx, global.TcpCheckUrl)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
udpCheckOption, err := dialer.ParseUdpCheckOption(ctx, global.UdpCheckDns)
if err != nil {
return nil, fmt.Errorf("failed to parse udp_check_dns: %w", err)
}
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
TcpCheckOption: tcpCheckOption, TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: global.TcpCheckUrl},
UdpCheckOption: udpCheckOption, UdpCheckOptionRaw: dialer.UdpCheckOptionRaw{Raw: global.UdpCheckDns},
CheckInterval: global.CheckInterval, CheckInterval: global.CheckInterval,
} }
outbounds := []*outbound.DialerGroup{ outbounds := []*outbound.DialerGroup{
@ -286,26 +276,55 @@ func NewControlPlane(
return nil, fmt.Errorf("RoutingMatcherBuilder.Build: %w", err) return nil, fmt.Errorf("RoutingMatcherBuilder.Build: %w", err)
} }
/// DNS upstream. /// Listen address.
var dnsUpstream *DnsUpstraem listenIp := "::1"
if !global.DnsUpstream.Empty { if len(global.WanInterface) > 0 {
if dnsUpstream, err = ResolveDnsUpstream(ctx, global.DnsUpstream.Url); err != nil { listenIp = "0.0.0.0"
return nil, err
} }
c = &ControlPlane{
log: log,
core: core,
deferFuncs: nil,
listenIp: listenIp,
outbounds: outbounds,
outboundName2Id: outboundName2Id,
SimulatedLpmTries: builder.SimulatedLpmTries,
SimulatedDomainSet: builder.SimulatedDomainSet,
Final: routingA.Final,
mutex: sync.Mutex{},
dnsCache: make(map[string]*dnsCache),
dnsUpstream: DnsUpstreamRaw{
Raw: global.DnsUpstream,
FinishInitCallback: nil,
},
}
c.dnsUpstream.FinishInitCallback = c.finishInitDnsUpstreamResolve
return c, nil
}
func (c *ControlPlane) finishInitDnsUpstreamResolve(raw common.UrlOrEmpty, dnsUpstream *DnsUpstream) (err error) {
/// Notify dialers to check.
for _, out := range c.outbounds {
for _, d := range out.Dialers {
d.NotifyCheck()
}
}
/// Updates dns cache to support domain routing for hostname of dns_upstream.
if !raw.Empty {
ip4in6 := dnsUpstream.Ip4.As16() ip4in6 := dnsUpstream.Ip4.As16()
ip6 := dnsUpstream.Ip6.As16() ip6 := dnsUpstream.Ip6.As16()
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{ if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]), Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]),
Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]), Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]),
HasIp4: dnsUpstream.Ip4.IsValid(), HasIp4: dnsUpstream.Ip4.IsValid(),
HasIp6: dnsUpstream.Ip6.IsValid(), HasIp6: dnsUpstream.Ip6.IsValid(),
Port: internal.Htons(dnsUpstream.Port), Port: internal.Htons(dnsUpstream.Port),
}, ebpf.UpdateAny); err != nil { }, ebpf.UpdateAny); err != nil {
return nil, err return err
} }
defer func() { /// Update dns cache to support domain routing for hostname of dns_upstream.
// Update dns cache to support domain routing for hostname of dns_upstream.
if err == nil {
// Ten years later. // Ten years later.
deadline := time.Now().Add(24 * time.Hour * 365 * 10) deadline := time.Now().Add(24 * time.Hour * 365 * 10)
fqdn := dnsUpstream.Hostname fqdn := dnsUpstream.Hostname
@ -350,11 +369,9 @@ func NewControlPlane(
return return
} }
} }
}
}()
} else { } else {
// Empty string. As-is. // Empty string. As-is.
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{ if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: [4]uint32{}, Ip4: [4]uint32{},
Ip6: [4]uint32{}, Ip6: [4]uint32{},
HasIp4: false, HasIp4: false,
@ -362,29 +379,10 @@ func NewControlPlane(
// Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array. // Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array.
Port: 0, Port: 0,
}, ebpf.UpdateAny); err != nil { }, ebpf.UpdateAny); err != nil {
return nil, err return err
} }
} }
return nil
/// Listen address.
listenIp := "::1"
if len(global.WanInterface) > 0 {
listenIp = "0.0.0.0"
}
return &ControlPlane{
log: log,
core: core,
deferFuncs: nil,
listenIp: listenIp,
outbounds: outbounds,
outboundName2Id: outboundName2Id,
SimulatedLpmTries: builder.SimulatedLpmTries,
SimulatedDomainSet: builder.SimulatedDomainSet,
Final: routingA.Final,
mutex: sync.Mutex{},
dnsCache: make(map[string]*dnsCache),
dnsUpstream: dnsUpstream,
}, nil
} }
func (c *ControlPlane) ListenAndServe(port uint16) (err error) { func (c *ControlPlane) ListenAndServe(port uint16) (err error) {

View File

@ -8,11 +8,14 @@ package control
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts" "github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/common/netutils" "github.com/v2rayA/dae/common/netutils"
"github.com/v2rayA/dae/component/outbound/dialer" "github.com/v2rayA/dae/component/outbound/dialer"
"net/url" "net/url"
"strconv" "strconv"
"sync"
"time"
) )
type DnsUpstreamScheme string type DnsUpstreamScheme string
@ -23,14 +26,14 @@ const (
DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp" DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp"
) )
type DnsUpstraem struct { type DnsUpstream struct {
Scheme DnsUpstreamScheme Scheme DnsUpstreamScheme
Hostname string Hostname string
Port uint16 Port uint16
*netutils.Ip46 *netutils.Ip46
} }
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstraem, err error) { func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstream, err error) {
var _port string var _port string
switch DnsUpstreamScheme(dnsUpstream.Scheme) { switch DnsUpstreamScheme(dnsUpstream.Scheme) {
case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP: case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP:
@ -46,6 +49,12 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
if err != nil {
_ = netutils.TryUpdateSystemDns1s()
}
}()
port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16) port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse dns_upstream port: %v", err) return nil, fmt.Errorf("parse dns_upstream port: %v", err)
@ -53,12 +62,12 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
hostname := dnsUpstream.Hostname() hostname := dnsUpstream.Hostname()
ip46, err := netutils.ParseIp46(ctx, dialer.SymmetricDirect, systemDns, hostname, false) ip46, err := netutils.ParseIp46(ctx, dialer.SymmetricDirect, systemDns, hostname, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream") return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
} }
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() { if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
return nil, fmt.Errorf("dns_upstream has no record") return nil, fmt.Errorf("dns_upstream has no record")
} }
return &DnsUpstraem{ return &DnsUpstream{
Scheme: DnsUpstreamScheme(dnsUpstream.Scheme), Scheme: DnsUpstreamScheme(dnsUpstream.Scheme),
Hostname: hostname, Hostname: hostname,
Port: uint16(port), Port: uint16(port),
@ -66,7 +75,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
}, nil }, nil
} }
func (u *DnsUpstraem) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) { func (u *DnsUpstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
if u.Ip4.IsValid() && u.Ip6.IsValid() { if u.Ip4.IsValid() && u.Ip6.IsValid() {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6} ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6}
} else { } else {
@ -87,3 +96,38 @@ func (u *DnsUpstraem) SupportedNetworks() (ipversions []consts.IpVersionStr, l4p
} }
return ipversions, l4protos return ipversions, l4protos
} }
type DnsUpstreamRaw struct {
Raw common.UrlOrEmpty
// FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw common.UrlOrEmpty, upstream *DnsUpstream) (err error)
mu sync.Mutex
upstream *DnsUpstream
init bool
}
func (u *DnsUpstreamRaw) Upstream() (_ *DnsUpstream, err error) {
u.mu.Lock()
defer u.mu.Unlock()
if !u.init {
defer func() {
if err == nil {
if err = u.FinishInitCallback(u.Raw, u.upstream); err != nil {
u.upstream = nil
return
}
u.init = true
}
}()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if !u.Raw.Empty {
if u.upstream, err = ResolveDnsUpstream(ctx, u.Raw.Url); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
}
} else {
// Empty string. As-is.
}
}
return u.upstream, nil
}

View File

@ -179,13 +179,17 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
// For DNS request, modify dst to dns upstream. // For DNS request, modify dst to dns upstream.
// NOTICE: We might modify l4proto and ipversion. // NOTICE: We might modify l4proto and ipversion.
if isDns && c.dnsUpstream != nil { dnsUpstream, err := c.dnsUpstream.Upstream()
if err != nil {
return err
}
if isDns && dnsUpstream != nil {
// Modify dns target to upstream. // Modify dns target to upstream.
// NOTICE: Routing was calculated in advance by the eBPF program. // NOTICE: Routing was calculated in advance by the eBPF program.
/// Choose the best l4proto and ipversion. /// Choose the best l4proto and ipversion.
// Get available ipversions and l4protos for DNS upstream. // Get available ipversions and l4protos for DNS upstream.
ipversions, l4protos := c.dnsUpstream.SupportedNetworks() ipversions, l4protos := dnsUpstream.SupportedNetworks()
var ( var (
bestDialer *dialer.Dialer bestDialer *dialer.Dialer
bestLatency time.Duration bestLatency time.Duration
@ -219,9 +223,9 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
} }
switch ipversion { switch ipversion {
case consts.IpVersionStr_4: case consts.IpVersionStr_4:
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip4, c.dnsUpstream.Port) bestTarget = netip.AddrPortFrom(dnsUpstream.Ip4, dnsUpstream.Port)
case consts.IpVersionStr_6: case consts.IpVersionStr_6:
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip6, c.dnsUpstream.Port) bestTarget = netip.AddrPortFrom(dnsUpstream.Ip6, dnsUpstream.Port)
} }
dialerForNew = bestDialer dialerForNew = bestDialer
dummyFrom = &dst dummyFrom = &dst