feat: lazily init dns upstream and tcp/udp check to avoid fatal when start

This commit is contained in:
mzz2017 2023-02-09 19:22:37 +08:00
parent f0f68ffb84
commit 3060417be7
6 changed files with 346 additions and 141 deletions

View File

@ -20,20 +20,50 @@ import (
)
var (
systemDnsMu sync.Mutex
systemDns netip.AddrPort
systemDnsMu sync.Mutex
systemDns netip.AddrPort
systemDnsNextUpdateAfter time.Time
)
func TryUpdateSystemDns() (err error) {
systemDnsMu.Lock()
err = tryUpdateSystemDns()
systemDnsMu.Unlock()
return err
}
// TryUpdateSystemDns1s will update system DNS if 1 second has elapsed since the last TryUpdateSystemDns1s call.
func TryUpdateSystemDns1s() (err error) {
systemDnsMu.Lock()
defer systemDnsMu.Unlock()
if time.Now().Before(systemDnsNextUpdateAfter) {
return fmt.Errorf("update too quickly")
}
err = tryUpdateSystemDns()
if err != nil {
return err
}
systemDnsNextUpdateAfter = time.Now().Add(time.Second)
return nil
}
func tryUpdateSystemDns() (err error) {
dnsConf := dnsReadConfig("/etc/resolv.conf")
if len(dnsConf.servers) == 0 {
err = fmt.Errorf("no valid dns server in /etc/resolv.conf")
return err
}
systemDns = netip.MustParseAddrPort(dnsConf.servers[0])
return nil
}
func SystemDns() (dns netip.AddrPort, err error) {
systemDnsMu.Lock()
defer systemDnsMu.Unlock()
if !systemDns.IsValid() {
dnsConf := dnsReadConfig("/etc/resolv.conf")
if len(dnsConf.servers) == 0 {
err = fmt.Errorf("no valid dns server in /etc/resolv.conf")
if err = tryUpdateSystemDns(); err != nil {
return netip.AddrPort{}, err
}
systemDns = netip.MustParseAddrPort(dnsConf.servers[0])
}
return systemDns, nil
}

View File

@ -21,6 +21,7 @@ import (
"path"
"strconv"
"strings"
"sync"
"time"
)
@ -77,6 +78,11 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = netutils.TryUpdateSystemDns1s()
}
}()
u, err := url.Parse(rawURL)
if err != nil {
@ -103,6 +109,11 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = netutils.TryUpdateSystemDns1s()
}
}()
host, _port, err := net.SplitHostPort(dnsHostPort)
if err != nil {
@ -123,6 +134,48 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck
}, nil
}
type TcpCheckOptionRaw struct {
opt *TcpCheckOption
mu sync.Mutex
Raw string
}
func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
tcpCheckOption, err := ParseTcpCheckOption(ctx, c.Raw)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
c.opt = tcpCheckOption
}
return c.opt, nil
}
type UdpCheckOptionRaw struct {
opt *UdpCheckOption
mu sync.Mutex
Raw string
}
func (c *UdpCheckOptionRaw) Option() (opt *UdpCheckOption, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
udpCheckOption, err := ParseUdpCheckOption(ctx, c.Raw)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
c.opt = udpCheckOption
}
return c.opt, nil
}
type CheckOption struct {
L4proto consts.L4ProtoStr
IpVersion consts.IpVersionStr
@ -146,28 +199,44 @@ func (d *Dialer) aliveBackground() {
L4proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.HttpCheck(ctx, d.TcpCheckOption.Url, d.TcpCheckOption.Ip4)
opt, err := d.TcpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.HttpCheck(ctx, opt.Url, opt.Ip4)
},
}
tcp6CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6,
CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.HttpCheck(ctx, d.TcpCheckOption.Url, d.TcpCheckOption.Ip6)
opt, err := d.TcpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.HttpCheck(ctx, opt.Url, opt.Ip6)
},
}
udp4CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_4,
CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.DnsCheck(ctx, netip.AddrPortFrom(d.UdpCheckOption.Ip4, d.UdpCheckOption.DnsPort))
opt, err := d.UdpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort))
},
}
udp6CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_6,
CheckFunc: func(ctx context.Context) (ok bool, err error) {
return d.DnsCheck(ctx, netip.AddrPortFrom(d.UdpCheckOption.Ip4, d.UdpCheckOption.DnsPort))
opt, err := d.UdpCheckOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort))
},
}
// Check once immediately.
@ -176,25 +245,72 @@ func (d *Dialer) aliveBackground() {
go d.Check(timeout, tcp6CheckOpt)
go d.Check(timeout, udp6CheckOpt)
// Sleep to avoid avalanche.
time.Sleep(time.Duration(fastrand.Int63n(int64(cycle))))
d.tickerMu.Lock()
d.ticker.Reset(cycle)
d.tickerMu.Unlock()
for range d.ticker.C {
ctx, cancel := context.WithCancel(d.ctx)
defer cancel()
go func() {
/// Splice ticker.C to checkCh.
// Sleep to avoid avalanche.
time.Sleep(time.Duration(fastrand.Int63n(int64(cycle))))
d.tickerMu.Lock()
d.ticker = time.NewTicker(cycle)
d.tickerMu.Unlock()
for t := range d.ticker.C {
select {
case <-ctx.Done():
return
default:
d.checkCh <- t
}
}
}()
var wg sync.WaitGroup
for range d.checkCh {
// No need to test if there is no dialer selection policy using its latency.
if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 {
go d.Check(timeout, tcp4CheckOpt)
wg.Add(1)
go func() {
d.Check(timeout, tcp4CheckOpt)
wg.Done()
}()
}
if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 {
go d.Check(timeout, tcp6CheckOpt)
wg.Add(1)
go func() {
d.Check(timeout, tcp6CheckOpt)
wg.Done()
}()
}
if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 {
go d.Check(timeout, udp4CheckOpt)
wg.Add(1)
go func() {
d.Check(timeout, udp4CheckOpt)
wg.Done()
}()
}
if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 {
go d.Check(timeout, udp6CheckOpt)
wg.Add(1)
go func() {
d.Check(timeout, udp6CheckOpt)
wg.Done()
}()
}
// Wait to block the loop.
wg.Wait()
}
}
// NotifyCheck will succeed only when CheckEnabled is true.
func (d *Dialer) NotifyCheck() {
select {
case <-d.ctx.Done():
return
default:
}
select {
// If fail to push elem to chan, the check is in process.
case d.checkCh <- time.Now():
default:
}
}

View File

@ -1,6 +1,7 @@
package dialer
import (
"context"
"fmt"
"github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
@ -26,13 +27,16 @@ type Dialer struct {
tickerMu sync.Mutex
ticker *time.Ticker
checkCh chan time.Time
ctx context.Context
cancel context.CancelFunc
}
type GlobalOption struct {
Log *logrus.Logger
TcpCheckOption *TcpCheckOption
UdpCheckOption *UdpCheckOption
CheckInterval time.Duration
Log *logrus.Logger
TcpCheckOptionRaw TcpCheckOptionRaw // Lazy parse
UdpCheckOptionRaw UdpCheckOptionRaw // Lazy parse
CheckInterval time.Duration
}
type InstanceOption struct {
@ -47,16 +51,21 @@ func NewDialer(dialer proxy.Dialer, option *GlobalOption, iOption InstanceOption
for i := range collections {
collections[i] = newCollection()
}
ctx, cancel := context.WithCancel(context.Background())
d := &Dialer{
GlobalOption: option,
instanceOption: iOption,
Dialer: dialer,
name: name,
protocol: protocol,
link: link,
collections: collections,
// Set a very big cycle to wait for init.
ticker: time.NewTicker(time.Hour),
GlobalOption: option,
instanceOption: iOption,
Dialer: dialer,
name: name,
protocol: protocol,
link: link,
collectionFineMu: sync.Mutex{},
collections: collections,
tickerMu: sync.Mutex{},
ticker: nil,
checkCh: make(chan time.Time, 1),
ctx: ctx,
cancel: cancel,
}
if iOption.CheckEnabled {
go d.aliveBackground()
@ -65,9 +74,13 @@ func NewDialer(dialer proxy.Dialer, option *GlobalOption, iOption InstanceOption
}
func (d *Dialer) Close() error {
d.cancel()
d.tickerMu.Lock()
d.ticker.Stop()
if d.ticker != nil {
d.ticker.Stop()
}
d.tickerMu.Unlock()
close(d.checkCh)
return nil
}

View File

@ -51,7 +51,7 @@ type ControlPlane struct {
// mutex protects the dnsCache.
mutex sync.Mutex
dnsCache map[string]*dnsCache
dnsUpstream *DnsUpstraem
dnsUpstream DnsUpstreamRaw
}
func NewControlPlane(
@ -191,21 +191,11 @@ func NewControlPlane(
}
/// DialerGroups (outbounds).
ctx, cancel := context.WithTimeout(context.TODO(), 30*time.Second)
defer cancel()
tcpCheckOption, err := dialer.ParseTcpCheckOption(ctx, global.TcpCheckUrl)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
udpCheckOption, err := dialer.ParseUdpCheckOption(ctx, global.UdpCheckDns)
if err != nil {
return nil, fmt.Errorf("failed to parse udp_check_dns: %w", err)
}
option := &dialer.GlobalOption{
Log: log,
TcpCheckOption: tcpCheckOption,
UdpCheckOption: udpCheckOption,
CheckInterval: global.CheckInterval,
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: global.TcpCheckUrl},
UdpCheckOptionRaw: dialer.UdpCheckOptionRaw{Raw: global.UdpCheckDns},
CheckInterval: global.CheckInterval,
}
outbounds := []*outbound.DialerGroup{
outbound.NewDialerGroup(option, consts.OutboundDirect.String(),
@ -286,92 +276,13 @@ func NewControlPlane(
return nil, fmt.Errorf("RoutingMatcherBuilder.Build: %w", err)
}
/// DNS upstream.
var dnsUpstream *DnsUpstraem
if !global.DnsUpstream.Empty {
if dnsUpstream, err = ResolveDnsUpstream(ctx, global.DnsUpstream.Url); err != nil {
return nil, err
}
ip4in6 := dnsUpstream.Ip4.As16()
ip6 := dnsUpstream.Ip6.As16()
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]),
Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]),
HasIp4: dnsUpstream.Ip4.IsValid(),
HasIp6: dnsUpstream.Ip6.IsValid(),
Port: internal.Htons(dnsUpstream.Port),
}, ebpf.UpdateAny); err != nil {
return nil, err
}
defer func() {
// Update dns cache to support domain routing for hostname of dns_upstream.
if err == nil {
// Ten years later.
deadline := time.Now().Add(24 * time.Hour * 365 * 10)
fqdn := dnsUpstream.Hostname
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
}
if dnsUpstream.Ip4.IsValid() {
typ := dnsmessage.TypeA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AResource{
A: dnsUpstream.Ip4.As4(),
},
}}
if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil {
c = nil
return
}
}
if dnsUpstream.Ip6.IsValid() {
typ := dnsmessage.TypeAAAA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AAAAResource{
AAAA: dnsUpstream.Ip6.As16(),
},
}}
if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil {
c = nil
return
}
}
}
}()
} else {
// Empty string. As-is.
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: [4]uint32{},
Ip6: [4]uint32{},
HasIp4: false,
HasIp6: false,
// Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array.
Port: 0,
}, ebpf.UpdateAny); err != nil {
return nil, err
}
}
/// Listen address.
listenIp := "::1"
if len(global.WanInterface) > 0 {
listenIp = "0.0.0.0"
}
return &ControlPlane{
c = &ControlPlane{
log: log,
core: core,
deferFuncs: nil,
@ -383,8 +294,95 @@ func NewControlPlane(
Final: routingA.Final,
mutex: sync.Mutex{},
dnsCache: make(map[string]*dnsCache),
dnsUpstream: dnsUpstream,
}, nil
dnsUpstream: DnsUpstreamRaw{
Raw: global.DnsUpstream,
FinishInitCallback: nil,
},
}
c.dnsUpstream.FinishInitCallback = c.finishInitDnsUpstreamResolve
return c, nil
}
func (c *ControlPlane) finishInitDnsUpstreamResolve(raw common.UrlOrEmpty, dnsUpstream *DnsUpstream) (err error) {
/// Notify dialers to check.
for _, out := range c.outbounds {
for _, d := range out.Dialers {
d.NotifyCheck()
}
}
/// Updates dns cache to support domain routing for hostname of dns_upstream.
if !raw.Empty {
ip4in6 := dnsUpstream.Ip4.As16()
ip6 := dnsUpstream.Ip6.As16()
if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]),
Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]),
HasIp4: dnsUpstream.Ip4.IsValid(),
HasIp6: dnsUpstream.Ip6.IsValid(),
Port: internal.Htons(dnsUpstream.Port),
}, ebpf.UpdateAny); err != nil {
return err
}
/// Update dns cache to support domain routing for hostname of dns_upstream.
// Ten years later.
deadline := time.Now().Add(24 * time.Hour * 365 * 10)
fqdn := dnsUpstream.Hostname
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
}
if dnsUpstream.Ip4.IsValid() {
typ := dnsmessage.TypeA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AResource{
A: dnsUpstream.Ip4.As4(),
},
}}
if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil {
c = nil
return
}
}
if dnsUpstream.Ip6.IsValid() {
typ := dnsmessage.TypeAAAA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AAAAResource{
AAAA: dnsUpstream.Ip6.As16(),
},
}}
if err = c.UpdateDnsCache(fqdn, typ, answers, deadline); err != nil {
c = nil
return
}
}
} else {
// Empty string. As-is.
if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: [4]uint32{},
Ip6: [4]uint32{},
HasIp4: false,
HasIp6: false,
// Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array.
Port: 0,
}, ebpf.UpdateAny); err != nil {
return err
}
}
return nil
}
func (c *ControlPlane) ListenAndServe(port uint16) (err error) {

View File

@ -8,11 +8,14 @@ package control
import (
"context"
"fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/common/netutils"
"github.com/v2rayA/dae/component/outbound/dialer"
"net/url"
"strconv"
"sync"
"time"
)
type DnsUpstreamScheme string
@ -23,14 +26,14 @@ const (
DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp"
)
type DnsUpstraem struct {
type DnsUpstream struct {
Scheme DnsUpstreamScheme
Hostname string
Port uint16
*netutils.Ip46
}
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstraem, err error) {
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstream, err error) {
var _port string
switch DnsUpstreamScheme(dnsUpstream.Scheme) {
case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP:
@ -46,6 +49,12 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = netutils.TryUpdateSystemDns1s()
}
}()
port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16)
if err != nil {
return nil, fmt.Errorf("parse dns_upstream port: %v", err)
@ -53,12 +62,12 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
hostname := dnsUpstream.Hostname()
ip46, err := netutils.ParseIp46(ctx, dialer.SymmetricDirect, systemDns, hostname, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream")
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
}
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
return nil, fmt.Errorf("dns_upstream has no record")
}
return &DnsUpstraem{
return &DnsUpstream{
Scheme: DnsUpstreamScheme(dnsUpstream.Scheme),
Hostname: hostname,
Port: uint16(port),
@ -66,7 +75,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
}, nil
}
func (u *DnsUpstraem) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
func (u *DnsUpstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
if u.Ip4.IsValid() && u.Ip6.IsValid() {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6}
} else {
@ -87,3 +96,38 @@ func (u *DnsUpstraem) SupportedNetworks() (ipversions []consts.IpVersionStr, l4p
}
return ipversions, l4protos
}
type DnsUpstreamRaw struct {
Raw common.UrlOrEmpty
// FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw common.UrlOrEmpty, upstream *DnsUpstream) (err error)
mu sync.Mutex
upstream *DnsUpstream
init bool
}
func (u *DnsUpstreamRaw) Upstream() (_ *DnsUpstream, err error) {
u.mu.Lock()
defer u.mu.Unlock()
if !u.init {
defer func() {
if err == nil {
if err = u.FinishInitCallback(u.Raw, u.upstream); err != nil {
u.upstream = nil
return
}
u.init = true
}
}()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if !u.Raw.Empty {
if u.upstream, err = ResolveDnsUpstream(ctx, u.Raw.Url); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
}
} else {
// Empty string. As-is.
}
}
return u.upstream, nil
}

View File

@ -179,13 +179,17 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
// For DNS request, modify dst to dns upstream.
// NOTICE: We might modify l4proto and ipversion.
if isDns && c.dnsUpstream != nil {
dnsUpstream, err := c.dnsUpstream.Upstream()
if err != nil {
return err
}
if isDns && dnsUpstream != nil {
// Modify dns target to upstream.
// NOTICE: Routing was calculated in advance by the eBPF program.
/// Choose the best l4proto and ipversion.
// Get available ipversions and l4protos for DNS upstream.
ipversions, l4protos := c.dnsUpstream.SupportedNetworks()
ipversions, l4protos := dnsUpstream.SupportedNetworks()
var (
bestDialer *dialer.Dialer
bestLatency time.Duration
@ -219,9 +223,9 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
}
switch ipversion {
case consts.IpVersionStr_4:
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip4, c.dnsUpstream.Port)
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip4, dnsUpstream.Port)
case consts.IpVersionStr_6:
bestTarget = netip.AddrPortFrom(c.dnsUpstream.Ip6, c.dnsUpstream.Port)
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip6, dnsUpstream.Port)
}
dialerForNew = bestDialer
dummyFrom = &dst