feat: support to check independent tcp dns connectivity

This commit is contained in:
mzz2017 2023-02-12 15:39:00 +08:00
parent c43b6887d7
commit 4c2f936fa4
15 changed files with 508 additions and 315 deletions

View File

@ -5,7 +5,10 @@
package consts
import "net/netip"
import (
"golang.org/x/sys/unix"
"net/netip"
)
type DialerSelectionPolicy string
@ -27,6 +30,16 @@ const (
L4ProtoStr_UDP L4ProtoStr = "udp"
)
func (l L4ProtoStr) ToL4Proto() uint8 {
switch l {
case L4ProtoStr_TCP:
return unix.IPPROTO_TCP
case L4ProtoStr_UDP:
return unix.IPPROTO_IDP
}
panic("unsupported l4proto")
}
type IpVersionStr string
const (
@ -34,6 +47,16 @@ const (
IpVersionStr_6 IpVersionStr = "6"
)
func (v IpVersionStr) ToIpVersion() uint8 {
switch v {
case IpVersionStr_4:
return 4
case IpVersionStr_6:
return 6
}
panic("unsupported ipversion")
}
func IpVersionFromAddr(addr netip.Addr) IpVersionStr {
var ipversion IpVersionStr
switch {

View File

@ -7,11 +7,13 @@ package netutils
import (
"context"
"encoding/binary"
"fmt"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/net/proxy"
"io"
"math"
"net/netip"
"strings"
@ -68,7 +70,7 @@ func SystemDns() (dns netip.AddrPort, err error) {
return systemDns, nil
}
func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type) (addrs []netip.Addr, err error) {
func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (addrs []netip.Addr, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if addr, err := netip.ParseAddr(host); err == nil {
@ -111,10 +113,23 @@ func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host
if err != nil {
return nil, err
}
if tcp {
buf := pool.Get(2 + len(b))
defer pool.Put(buf)
binary.BigEndian.PutUint16(buf, uint16(len(b)))
copy(buf[2:], b)
b = buf
}
// Dial and write.
var network string
if tcp {
network = "tcp"
} else {
network = "udp"
}
cd := ContextDialer{d}
c, err := cd.DialContext(ctx, "udp", dns.String())
c, err := cd.DialContext(ctx, network, dns.String())
if err != nil {
return nil, err
}
@ -124,24 +139,40 @@ func ResolveNetip(ctx context.Context, d proxy.Dialer, dns netip.AddrPort, host
return nil, err
}
ch := make(chan error, 2)
go func() {
// Resend every 3 seconds.
for {
select {
case <-ctx.Done():
return
default:
time.Sleep(3 * time.Second)
if !tcp {
go func() {
// Resend every 3 seconds for UDP.
for {
select {
case <-ctx.Done():
return
default:
time.Sleep(3 * time.Second)
}
_, err := c.Write(b)
if err != nil {
ch <- err
return
}
}
_, err := c.Write(b)
}()
}
go func() {
buf := pool.Get(512)
defer pool.Put(buf)
if tcp {
_, err := io.ReadFull(c, buf[:2])
if err != nil {
ch <- err
return
}
n := binary.BigEndian.Uint16(buf)
if n > 512 {
ch <- fmt.Errorf("too big dns resp")
return
}
buf = buf[:n]
}
}()
go func() {
buf := pool.Get(512)
n, err := c.Read(buf)
if err != nil {
ch <- err

View File

@ -18,8 +18,8 @@ type Ip46 struct {
Ip6 netip.Addr
}
func ParseIp46(ctx context.Context, dialer proxy.Dialer, dns netip.AddrPort, host string, must46 bool) (ipv46 *Ip46, err error) {
addrs4, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeA)
func ParseIp46(ctx context.Context, dialer proxy.Dialer, dns netip.AddrPort, host string, must46 bool, tcp bool) (ipv46 *Ip46, err error) {
addrs4, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeA, tcp)
if err != nil {
return nil, err
}
@ -30,7 +30,7 @@ func ParseIp46(ctx context.Context, dialer proxy.Dialer, dns netip.AddrPort, hos
addrs4 = []netip.Addr{{}}
}
}
addrs6, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeAAAA)
addrs6, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeAAAA, tcp)
if err != nil {
return nil, err
}

View File

@ -24,8 +24,7 @@ type minLatency struct {
type AliveDialerSet struct {
log *logrus.Logger
dialerGroupName string
l4proto consts.L4ProtoStr
ipversion consts.IpVersionStr
CheckTyp *NetworkType
tolerance time.Duration
aliveChangeCallback func(alive bool)
@ -42,8 +41,7 @@ type AliveDialerSet struct {
func NewAliveDialerSet(
log *logrus.Logger,
dialerGroupName string,
l4proto consts.L4ProtoStr,
ipversion consts.IpVersionStr,
networkType *NetworkType,
tolerance time.Duration,
selectionPolicy consts.DialerSelectionPolicy,
dialers []*Dialer,
@ -53,8 +51,7 @@ func NewAliveDialerSet(
a := &AliveDialerSet{
log: log,
dialerGroupName: dialerGroupName,
l4proto: l4proto,
ipversion: ipversion,
CheckTyp: networkType,
tolerance: tolerance,
aliveChangeCallback: aliveChangeCallback,
dialerToIndex: make(map[*Dialer]int),
@ -102,10 +99,10 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
switch a.selectionPolicy {
case consts.DialerSelectionPolicy_MinLastLatency:
latency, hasLatency = dialer.MustGetLatencies10(a.l4proto, a.ipversion).LastLatency()
latency, hasLatency = dialer.MustGetLatencies10(a.CheckTyp).LastLatency()
minPolicy = true
case consts.DialerSelectionPolicy_MinAverage10Latencies:
latency, hasLatency = dialer.MustGetLatencies10(a.l4proto, a.ipversion).AvgLatency()
latency, hasLatency = dialer.MustGetLatencies10(a.CheckTyp).AvgLatency()
minPolicy = true
}
@ -180,7 +177,7 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
a.log.WithFields(logrus.Fields{
string(a.selectionPolicy): a.minLatency.latency,
"group": a.dialerGroupName,
"network": string(a.l4proto) + string(a.ipversion),
"network": a.CheckTyp.String(),
"new dialer": a.minLatency.dialer.Name(),
"old dialer": oldDialerName,
}).Infof("Group %vselects dialer", re)
@ -189,7 +186,7 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
defer a.aliveChangeCallback(false)
a.log.WithFields(logrus.Fields{
"group": a.dialerGroupName,
"network": string(a.l4proto) + string(a.ipversion),
"network": a.CheckTyp.String(),
}).Infof("Group has no dialer alive")
}
}
@ -199,7 +196,7 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
a.minLatency.dialer = dialer
a.log.WithFields(logrus.Fields{
"group": a.dialerGroupName,
"network": string(a.l4proto) + string(a.ipversion),
"network": a.CheckTyp.String(),
"dialer": a.minLatency.dialer.Name(),
}).Infof("Group selects dialer")
}

View File

@ -25,6 +25,20 @@ import (
"time"
)
type NetworkType struct {
L4Proto consts.L4ProtoStr
IpVersion consts.IpVersionStr
IsDns bool
}
func (t *NetworkType) String() string {
if t.IsDns {
return string(t.L4Proto) + string(t.IpVersion) + "(DNS)"
} else {
return string(t.L4Proto) + string(t.IpVersion)
}
}
type collection struct {
// AliveDialerSetSet uses reference counting.
AliveDialerSetSet AliveDialerSetSet
@ -40,32 +54,45 @@ func newCollection() *collection {
}
}
func (d *Dialer) mustGetCollection(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) *collection {
switch l4proto {
case consts.L4ProtoStr_TCP:
switch ipversion {
case consts.IpVersionStr_4:
return d.collections[0]
case consts.IpVersionStr_6:
return d.collections[1]
func (d *Dialer) mustGetCollection(typ *NetworkType) *collection {
if typ.IsDns {
switch typ.L4Proto {
case consts.L4ProtoStr_TCP:
switch typ.IpVersion {
case consts.IpVersionStr_4:
return d.collections[0]
case consts.IpVersionStr_6:
return d.collections[1]
}
case consts.L4ProtoStr_UDP:
switch typ.IpVersion {
case consts.IpVersionStr_4:
return d.collections[2]
case consts.IpVersionStr_6:
return d.collections[3]
}
}
case consts.L4ProtoStr_UDP:
switch ipversion {
case consts.IpVersionStr_4:
return d.collections[2]
case consts.IpVersionStr_6:
return d.collections[3]
} else {
switch typ.L4Proto {
case consts.L4ProtoStr_TCP:
switch typ.IpVersion {
case consts.IpVersionStr_4:
return d.collections[4]
case consts.IpVersionStr_6:
return d.collections[5]
}
case consts.L4ProtoStr_UDP:
}
}
panic("invalid param")
}
func (d *Dialer) MustGetLatencies10(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) *LatenciesN {
return d.mustGetCollection(l4proto, ipversion).Latencies10
func (d *Dialer) MustGetLatencies10(typ *NetworkType) *LatenciesN {
return d.mustGetCollection(typ).Latencies10
}
func (d *Dialer) MustGetAlive(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) bool {
return d.mustGetCollection(l4proto, ipversion).Alive
func (d *Dialer) MustGetAlive(typ *NetworkType) bool {
return d.mustGetCollection(typ).Alive
}
type TcpCheckOption struct {
@ -88,7 +115,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
if err != nil {
return nil, err
}
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, u.Hostname(), true)
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, u.Hostname(), true, false)
if err != nil {
return nil, err
}
@ -98,13 +125,13 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
}, nil
}
type UdpCheckOption struct {
type CheckDnsOption struct {
DnsHost string
DnsPort uint16
*netutils.Ip46
}
func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheckOption, err error) {
func ParseCheckDnsOption(ctx context.Context, dnsHostPort string) (opt *CheckDnsOption, err error) {
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
@ -123,11 +150,11 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck
if err != nil {
return nil, fmt.Errorf("bad port: %v", err)
}
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, host, true)
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, host, true, false)
if err != nil {
return nil, err
}
return &UdpCheckOption{
return &CheckDnsOption{
DnsHost: host,
DnsPort: uint16(port),
Ip46: ip46,
@ -155,19 +182,19 @@ func (c *TcpCheckOptionRaw) Option() (opt *TcpCheckOption, err error) {
return c.opt, nil
}
type UdpCheckOptionRaw struct {
opt *UdpCheckOption
type CheckDnsOptionRaw struct {
opt *CheckDnsOption
mu sync.Mutex
Raw string
}
func (c *UdpCheckOptionRaw) Option() (opt *UdpCheckOption, err error) {
func (c *CheckDnsOptionRaw) Option() (opt *CheckDnsOption, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.opt == nil {
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
udpCheckOption, err := ParseUdpCheckOption(ctx, c.Raw)
udpCheckOption, err := ParseCheckDnsOption(ctx, c.Raw)
if err != nil {
return nil, fmt.Errorf("failed to parse tcp_check_url: %w", err)
}
@ -177,9 +204,8 @@ func (c *UdpCheckOptionRaw) Option() (opt *UdpCheckOption, err error) {
}
type CheckOption struct {
L4proto consts.L4ProtoStr
IpVersion consts.IpVersionStr
CheckFunc func(ctx context.Context) (ok bool, err error)
networkType *NetworkType
CheckFunc func(ctx context.Context) (ok bool, err error)
}
func (d *Dialer) ActivateCheck() {
@ -196,8 +222,11 @@ func (d *Dialer) aliveBackground() {
timeout := 10 * time.Second
cycle := d.CheckInterval
tcp4CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
IsDns: false,
},
CheckFunc: func(ctx context.Context) (ok bool, err error) {
opt, err := d.TcpCheckOptionRaw.Option()
if err != nil {
@ -207,8 +236,11 @@ func (d *Dialer) aliveBackground() {
},
}
tcp6CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6,
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6,
IsDns: false,
},
CheckFunc: func(ctx context.Context) (ok bool, err error) {
opt, err := d.TcpCheckOptionRaw.Option()
if err != nil {
@ -217,33 +249,75 @@ func (d *Dialer) aliveBackground() {
return d.HttpCheck(ctx, opt.Url, opt.Ip6)
},
}
udp4CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_4,
tcp4CheckDnsOpt := &CheckOption{
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
IsDns: true,
},
CheckFunc: func(ctx context.Context) (ok bool, err error) {
opt, err := d.UdpCheckOptionRaw.Option()
opt, err := d.CheckDnsOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort))
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), true)
},
}
udp6CheckOpt := &CheckOption{
L4proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_6,
tcp6CheckDnsOpt := &CheckOption{
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6,
IsDns: true,
},
CheckFunc: func(ctx context.Context) (ok bool, err error) {
opt, err := d.UdpCheckOptionRaw.Option()
opt, err := d.CheckDnsOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort))
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), true)
},
}
udp4CheckDnsOpt := &CheckOption{
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_4,
IsDns: true,
},
CheckFunc: func(ctx context.Context) (ok bool, err error) {
opt, err := d.CheckDnsOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip4, opt.DnsPort), false)
},
}
udp6CheckDnsOpt := &CheckOption{
networkType: &NetworkType{
L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_6,
IsDns: true,
},
CheckFunc: func(ctx context.Context) (ok bool, err error) {
opt, err := d.CheckDnsOptionRaw.Option()
if err != nil {
return false, err
}
return d.DnsCheck(ctx, netip.AddrPortFrom(opt.Ip6, opt.DnsPort), false)
},
}
var CheckOpts = []*CheckOption{
tcp4CheckOpt,
tcp6CheckOpt,
udp4CheckDnsOpt,
udp6CheckDnsOpt,
tcp4CheckDnsOpt,
tcp6CheckDnsOpt,
}
// Check once immediately.
go d.Check(timeout, tcp4CheckOpt)
go d.Check(timeout, udp4CheckOpt)
go d.Check(timeout, tcp6CheckOpt)
go d.Check(timeout, udp6CheckOpt)
for i := range CheckOpts {
opt := CheckOpts[i]
go d.Check(timeout, opt)
}
ctx, cancel := context.WithCancel(d.ctx)
defer cancel()
@ -266,33 +340,14 @@ func (d *Dialer) aliveBackground() {
var wg sync.WaitGroup
for range d.checkCh {
// No need to test if there is no dialer selection policy using its latency.
if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 {
wg.Add(1)
go func() {
d.Check(timeout, tcp4CheckOpt)
wg.Done()
}()
}
if len(d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 {
wg.Add(1)
go func() {
d.Check(timeout, tcp6CheckOpt)
wg.Done()
}()
}
if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_4).AliveDialerSetSet) > 0 {
wg.Add(1)
go func() {
d.Check(timeout, udp4CheckOpt)
wg.Done()
}()
}
if len(d.mustGetCollection(consts.L4ProtoStr_UDP, consts.IpVersionStr_6).AliveDialerSetSet) > 0 {
wg.Add(1)
go func() {
d.Check(timeout, udp6CheckOpt)
wg.Done()
}()
for _, opt := range CheckOpts {
if len(d.mustGetCollection(opt.networkType).AliveDialerSetSet) > 0 {
wg.Add(1)
go func(opt *CheckOption) {
d.Check(timeout, opt)
wg.Done()
}(opt)
}
}
// Wait to block the loop.
wg.Wait()
@ -315,17 +370,23 @@ func (d *Dialer) NotifyCheck() {
}
// RegisterAliveDialerSet is thread-safe.
func (d *Dialer) RegisterAliveDialerSet(a *AliveDialerSet, l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) {
func (d *Dialer) RegisterAliveDialerSet(a *AliveDialerSet) {
if a == nil {
return
}
d.collectionFineMu.Lock()
d.mustGetCollection(l4proto, ipversion).AliveDialerSetSet[a]++
d.mustGetCollection(a.CheckTyp).AliveDialerSetSet[a]++
d.collectionFineMu.Unlock()
}
// UnregisterAliveDialerSet is thread-safe.
func (d *Dialer) UnregisterAliveDialerSet(a *AliveDialerSet, l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) {
func (d *Dialer) UnregisterAliveDialerSet(a *AliveDialerSet) {
if a == nil {
return
}
d.collectionFineMu.Lock()
defer d.collectionFineMu.Unlock()
setSet := d.mustGetCollection(consts.L4ProtoStr_TCP, consts.IpVersionStr_4).AliveDialerSetSet
setSet := d.mustGetCollection(a.CheckTyp).AliveDialerSetSet
setSet[a]--
if setSet[a] <= 0 {
delete(setSet, a)
@ -339,16 +400,15 @@ func (d *Dialer) Check(timeout time.Duration,
defer cancel()
start := time.Now()
// Calc latency.
collection := d.mustGetCollection(opts.L4proto, opts.IpVersion)
collection := d.mustGetCollection(opts.networkType)
if ok, err = opts.CheckFunc(ctx); ok && err == nil {
// No error.
latency := time.Since(start)
latencies10 := d.mustGetCollection(opts.L4proto, opts.IpVersion).Latencies10
latencies10 := d.mustGetCollection(opts.networkType).Latencies10
latencies10.AppendLatency(latency)
avg, _ := latencies10.AvgLatency()
d.Log.WithFields(logrus.Fields{
// Add a space to ensure alphabetical order is first.
"network": string(opts.L4proto) + string(opts.IpVersion),
"network": opts.networkType.String(),
"node": d.name,
"last": latency.Truncate(time.Millisecond),
"avg_10": avg.Truncate(time.Millisecond),
@ -361,8 +421,7 @@ func (d *Dialer) Check(timeout time.Duration,
err = fmt.Errorf("network is unreachable")
}
d.Log.WithFields(logrus.Fields{
// Add a space to ensure alphabetical order is first.
"network": string(opts.L4proto) + string(opts.IpVersion),
"network": opts.networkType.String(),
"node": d.name,
"err": err.Error(),
}).Debugln("Connectivity Check Failed")
@ -412,8 +471,8 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr)
return resp.StatusCode >= 200 && resp.StatusCode < 400, nil
}
func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort) (ok bool, err error) {
addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA)
func (d *Dialer) DnsCheck(ctx context.Context, dns netip.AddrPort, tcp bool) (ok bool, err error) {
addrs, err := netutils.ResolveNetip(ctx, d, dns, consts.UdpCheckLookupHost, dnsmessage.TypeA, tcp)
if err != nil {
return false, err
}

View File

@ -23,7 +23,7 @@ type Dialer struct {
link string
collectionFineMu sync.Mutex
collections [4]*collection
collections [6]*collection
tickerMu sync.Mutex
ticker *time.Ticker
@ -35,9 +35,10 @@ type Dialer struct {
type GlobalOption struct {
Log *logrus.Logger
TcpCheckOptionRaw TcpCheckOptionRaw // Lazy parse
UdpCheckOptionRaw UdpCheckOptionRaw // Lazy parse
CheckDnsOptionRaw CheckDnsOptionRaw // Lazy parse
CheckInterval time.Duration
CheckTolerance time.Duration
CheckDnsTcp bool
}
type InstanceOption struct {
@ -48,7 +49,7 @@ type AliveDialerSetSet map[*AliveDialerSet]int
// NewDialer is for register in general.
func NewDialer(dialer proxy.Dialer, option *GlobalOption, iOption InstanceOption, name string, protocol string, link string) *Dialer {
var collections [4]*collection
var collections [6]*collection
for i := range collections {
collections[i] = newCollection()
}

View File

@ -11,51 +11,100 @@ import (
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
"golang.org/x/net/proxy"
"golang.org/x/sys/unix"
"net"
"net/netip"
"strings"
"time"
)
var NoAliveDialerError = fmt.Errorf("no alive dialer")
type DialerGroup struct {
proxy.Dialer
block *dialer.Dialer
log *logrus.Logger
Name string
Dialers []*dialer.Dialer
registeredAliveDialerSet bool
AliveTcp4DialerSet *dialer.AliveDialerSet
AliveTcp6DialerSet *dialer.AliveDialerSet
AliveUdp4DialerSet *dialer.AliveDialerSet
AliveUdp6DialerSet *dialer.AliveDialerSet
aliveDialerSets [6]*dialer.AliveDialerSet
selectionPolicy *DialerSelectionPolicy
}
func NewDialerGroup(option *dialer.GlobalOption, name string, dialers []*dialer.Dialer, p DialerSelectionPolicy, aliveChangeCallback func(alive bool, l4proto uint8, ipversion uint8)) *DialerGroup {
func NewDialerGroup(option *dialer.GlobalOption, name string, dialers []*dialer.Dialer, p DialerSelectionPolicy, aliveChangeCallback func(alive bool, networkType *dialer.NetworkType)) *DialerGroup {
log := option.Log
var registeredAliveDialerSet bool
aliveTcp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_4, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 4) }, true)
aliveTcp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_6, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 6) }, true)
aliveUdp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_4, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 4) }, true)
aliveUdp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_6, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 6) }, true)
var aliveDnsTcp4DialerSet *dialer.AliveDialerSet
var aliveDnsTcp6DialerSet *dialer.AliveDialerSet
var aliveTcp4DialerSet *dialer.AliveDialerSet
var aliveTcp6DialerSet *dialer.AliveDialerSet
var aliveDnsUdp4DialerSet *dialer.AliveDialerSet
var aliveDnsUdp6DialerSet *dialer.AliveDialerSet
switch p.Policy {
case consts.DialerSelectionPolicy_Random,
consts.DialerSelectionPolicy_MinLastLatency,
consts.DialerSelectionPolicy_MinAverage10Latencies:
// Need to know the alive state or latency.
for _, d := range dialers {
d.RegisterAliveDialerSet(aliveTcp4DialerSet, consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d.RegisterAliveDialerSet(aliveTcp6DialerSet, consts.L4ProtoStr_TCP, consts.IpVersionStr_6)
d.RegisterAliveDialerSet(aliveUdp4DialerSet, consts.L4ProtoStr_UDP, consts.IpVersionStr_4)
d.RegisterAliveDialerSet(aliveUdp6DialerSet, consts.L4ProtoStr_UDP, consts.IpVersionStr_6)
networkType := &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
IsDns: false,
}
aliveTcp4DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType) }
}(networkType), true)
networkType = &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6,
IsDns: false,
}
aliveTcp6DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType) }
}(networkType), true)
networkType = &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_4,
IsDns: true,
}
aliveDnsUdp4DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType) }
}(networkType), true)
networkType = &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_6,
IsDns: true,
}
aliveDnsUdp6DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType) }
}(networkType), true)
if option.CheckDnsTcp {
aliveDnsTcp4DialerSet = dialer.NewAliveDialerSet(log, name, &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
IsDns: true,
}, option.CheckTolerance, p.Policy, dialers, func(alive bool) {}, true)
aliveDnsTcp6DialerSet = dialer.NewAliveDialerSet(log, name, &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6,
IsDns: true,
}, option.CheckTolerance, p.Policy, dialers, func(alive bool) {}, true)
}
registeredAliveDialerSet = true
case consts.DialerSelectionPolicy_Fixed:
// No need to know if the dialer is alive.
@ -64,31 +113,35 @@ func NewDialerGroup(option *dialer.GlobalOption, name string, dialers []*dialer.
log.Panicf("Unexpected dialer selection policy: %v", p.Policy)
}
for _, d := range dialers {
d.RegisterAliveDialerSet(aliveTcp4DialerSet)
d.RegisterAliveDialerSet(aliveTcp6DialerSet)
d.RegisterAliveDialerSet(aliveDnsTcp4DialerSet)
d.RegisterAliveDialerSet(aliveDnsTcp6DialerSet)
d.RegisterAliveDialerSet(aliveDnsUdp4DialerSet)
d.RegisterAliveDialerSet(aliveDnsUdp6DialerSet)
}
return &DialerGroup{
log: log,
Name: name,
Dialers: dialers,
block: dialer.NewBlockDialer(option, func() {
log.WithFields(logrus.Fields{
"group": name,
}).Warnf("No alive dialer for given nerwork in DialerGroup, use \"block\".")
}),
AliveTcp4DialerSet: aliveTcp4DialerSet,
AliveTcp6DialerSet: aliveTcp6DialerSet,
AliveUdp4DialerSet: aliveUdp4DialerSet,
AliveUdp6DialerSet: aliveUdp6DialerSet,
registeredAliveDialerSet: registeredAliveDialerSet,
selectionPolicy: &p,
aliveDialerSets: [6]*dialer.AliveDialerSet{
aliveDnsTcp4DialerSet,
aliveDnsTcp6DialerSet,
aliveDnsUdp4DialerSet,
aliveDnsUdp6DialerSet,
aliveTcp4DialerSet,
aliveTcp6DialerSet,
},
selectionPolicy: &p,
}
}
func (g *DialerGroup) Close() error {
if g.registeredAliveDialerSet {
for _, d := range g.Dialers {
d.UnregisterAliveDialerSet(g.AliveTcp4DialerSet, consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d.UnregisterAliveDialerSet(g.AliveTcp6DialerSet, consts.L4ProtoStr_TCP, consts.IpVersionStr_6)
d.UnregisterAliveDialerSet(g.AliveUdp4DialerSet, consts.L4ProtoStr_UDP, consts.IpVersionStr_4)
d.UnregisterAliveDialerSet(g.AliveUdp6DialerSet, consts.L4ProtoStr_UDP, consts.IpVersionStr_6)
for _, d := range g.Dialers {
for _, a := range g.aliveDialerSets {
d.UnregisterAliveDialerSet(a)
}
}
return nil
@ -99,37 +152,53 @@ func (g *DialerGroup) SetSelectionPolicy(policy DialerSelectionPolicy) {
g.selectionPolicy = &policy
}
func (d *DialerGroup) MustGetAliveDialerSet(typ *dialer.NetworkType) *dialer.AliveDialerSet {
if typ.IsDns {
switch typ.L4Proto {
case consts.L4ProtoStr_TCP:
switch typ.IpVersion {
case consts.IpVersionStr_4:
return d.aliveDialerSets[0]
case consts.IpVersionStr_6:
return d.aliveDialerSets[1]
}
case consts.L4ProtoStr_UDP:
switch typ.IpVersion {
case consts.IpVersionStr_4:
return d.aliveDialerSets[2]
case consts.IpVersionStr_6:
return d.aliveDialerSets[3]
}
}
} else {
switch typ.L4Proto {
case consts.L4ProtoStr_TCP:
switch typ.IpVersion {
case consts.IpVersionStr_4:
return d.aliveDialerSets[4]
case consts.IpVersionStr_6:
return d.aliveDialerSets[5]
}
case consts.L4ProtoStr_UDP:
}
}
panic("invalid param")
}
// Select selects a dialer from group according to selectionPolicy.
func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) (d *dialer.Dialer, latency time.Duration, err error) {
func (g *DialerGroup) Select(networkType *dialer.NetworkType) (d *dialer.Dialer, latency time.Duration, err error) {
if len(g.Dialers) == 0 {
return nil, 0, fmt.Errorf("no dialer in this group")
}
var a *dialer.AliveDialerSet
switch l4proto {
case consts.L4ProtoStr_TCP:
switch ipversion {
case consts.IpVersionStr_4:
a = g.AliveTcp4DialerSet
case consts.IpVersionStr_6:
a = g.AliveTcp6DialerSet
}
case consts.L4ProtoStr_UDP:
switch ipversion {
case consts.IpVersionStr_4:
a = g.AliveUdp4DialerSet
case consts.IpVersionStr_6:
a = g.AliveUdp6DialerSet
}
default:
return nil, 0, fmt.Errorf("DialerGroup.Select: unexpected l4proto type: %v", l4proto)
}
a := g.MustGetAliveDialerSet(networkType)
switch g.selectionPolicy.Policy {
case consts.DialerSelectionPolicy_Random:
d := a.GetRand()
if d == nil {
// No alive dialer.
return g.block, time.Hour, nil
return nil, time.Hour, NoAliveDialerError
}
return d, 0, nil
@ -143,7 +212,7 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi
d, latency := a.GetMinLatency()
if d == nil {
// No alive dialer.
return g.block, time.Hour, nil
return nil, time.Hour, NoAliveDialerError
}
return d, latency, nil
@ -151,24 +220,3 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi
return nil, 0, fmt.Errorf("unsupported DialerSelectionPolicy: %v", g.selectionPolicy)
}
}
func (g *DialerGroup) Dial(network string, addr string) (c net.Conn, err error) {
var d proxy.Dialer
ipAddr, err := netip.ParseAddr(addr)
if err != nil {
return nil, fmt.Errorf("DialerGroup.Dial only supports ip as addr")
}
ipversion := consts.IpVersionFromAddr(ipAddr)
switch {
case strings.HasPrefix(network, "tcp"):
d, _, err = g.Select(consts.L4ProtoStr_TCP, ipversion)
case strings.HasPrefix(network, "udp"):
d, _, err = g.Select(consts.L4ProtoStr_UDP, ipversion)
default:
return nil, fmt.Errorf("unexpected network: %v", network)
}
if err != nil {
return nil, err
}
return d.Dial(network, addr)
}

View File

@ -6,7 +6,6 @@
package outbound
import (
"context"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
@ -20,21 +19,21 @@ const (
testUdpCheckDns = "https://connectivitycheck.gstatic.com/generate_204"
)
var TestNetworkType = &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
IsDns: false,
}
func TestDialerGroup_Select_Fixed(t *testing.T) {
log := logger.NewLogger("trace", false)
topt, err := dialer.ParseTcpCheckOption(context.TODO(), testTcpCheckUrl)
if err != nil {
t.Fatal(err)
}
uopt, err := dialer.ParseUdpCheckOption(context.TODO(), testUdpCheckDns)
if err != nil {
t.Fatal(err)
}
option := &dialer.GlobalOption{
Log: log,
TcpCheckOption: topt,
UdpCheckOption: uopt,
CheckInterval: 15 * time.Second,
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
CheckInterval: 15 * time.Second,
CheckTolerance: 0,
CheckDnsTcp: false,
}
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(option, true),
@ -44,9 +43,9 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: fixedIndex,
}, func(alive bool, l4proto uint8, ipversion uint8) {})
}, func(alive bool, networkType *dialer.NetworkType) {})
for i := 0; i < 10; i++ {
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(TestNetworkType)
if err != nil {
t.Fatal(err)
}
@ -58,7 +57,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
fixedIndex = 0
g.selectionPolicy.FixedIndex = fixedIndex
for i := 0; i < 10; i++ {
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(TestNetworkType)
if err != nil {
t.Fatal(err)
}
@ -70,19 +69,12 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
log := logger.NewLogger("trace", false)
topt, err := dialer.ParseTcpCheckOption(context.TODO(), testTcpCheckUrl)
if err != nil {
t.Fatal(err)
}
uopt, err := dialer.ParseUdpCheckOption(context.TODO(), testUdpCheckDns)
if err != nil {
t.Fatal(err)
}
option := &dialer.GlobalOption{
Log: log,
TcpCheckOption: topt,
UdpCheckOption: uopt,
CheckInterval: 15 * time.Second,
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(option, false),
@ -98,7 +90,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_MinLastLatency,
}, func(alive bool, l4proto uint8, ipversion uint8) {})
}, func(alive bool, networkType *dialer.NetworkType) {})
// Test 1000 times.
for i := 0; i < 1000; i++ {
@ -120,14 +112,14 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
latency = time.Duration(fastrand.Int63n(int64(1000 * time.Millisecond)))
alive = true
}
d.MustGetLatencies10(consts.L4ProtoStr_TCP, consts.IpVersionStr_4).AppendLatency(latency)
d.MustGetLatencies10(TestNetworkType).AppendLatency(latency)
if jMinLatency == -1 || latency < minLatency {
jMinLatency = j
minLatency = latency
}
g.AliveTcp4DialerSet.NotifyLatencyChange(d, alive)
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(d, alive)
}
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(TestNetworkType)
if err != nil {
t.Fatal(err)
}
@ -147,19 +139,12 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
func TestDialerGroup_Select_Random(t *testing.T) {
log := logger.NewLogger("trace", false)
topt, err := dialer.ParseTcpCheckOption(context.TODO(), testTcpCheckUrl)
if err != nil {
t.Fatal(err)
}
uopt, err := dialer.ParseUdpCheckOption(context.TODO(), testUdpCheckDns)
if err != nil {
t.Fatal(err)
}
option := &dialer.GlobalOption{
Log: log,
TcpCheckOption: topt,
UdpCheckOption: uopt,
CheckInterval: 15 * time.Second,
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(option, false),
@ -170,10 +155,10 @@ func TestDialerGroup_Select_Random(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
}, func(alive bool, l4proto uint8, ipversion uint8) {})
}, func(alive bool, networkType *dialer.NetworkType) {})
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(TestNetworkType)
if err != nil {
t.Fatal(err)
}
@ -194,19 +179,12 @@ func TestDialerGroup_Select_Random(t *testing.T) {
func TestDialerGroup_SetAlive(t *testing.T) {
log := logger.NewLogger("trace", false)
topt, err := dialer.ParseTcpCheckOption(context.TODO(), testTcpCheckUrl)
if err != nil {
t.Fatal(err)
}
uopt, err := dialer.ParseUdpCheckOption(context.TODO(), testUdpCheckDns)
if err != nil {
t.Fatal(err)
}
option := &dialer.GlobalOption{
Log: log,
TcpCheckOption: topt,
UdpCheckOption: uopt,
CheckInterval: 15 * time.Second,
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
dialer.NewDirectDialer(option, false),
@ -217,12 +195,12 @@ func TestDialerGroup_SetAlive(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
}, func(alive bool, l4proto uint8, ipversion uint8) {})
}, func(alive bool, networkType *dialer.NetworkType) {})
zeroTarget := 3
g.AliveTcp4DialerSet.NotifyLatencyChange(dialers[zeroTarget], false)
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(dialers[zeroTarget], false)
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, _, err := g.Select(consts.L4ProtoStr_UDP, consts.IpVersionStr_4)
d, _, err := g.Select(TestNetworkType)
if err != nil {
t.Fatal(err)
}

View File

@ -14,10 +14,15 @@ import (
)
type Global struct {
TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"`
LogLevel string `mapstructure:"log_level" default:"info"`
TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com"`
UdpCheckDns string `mapstructure:"udp_check_dns" default:"cloudflare-dns.com:53"`
TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"`
LogLevel string `mapstructure:"log_level" default:"info"`
// We use DirectTcpCheckUrl to check (tcp)*(ipv4/ipv6) connectivity for direct.
//DirectTcpCheckUrl string `mapstructure:"direct_tcp_check_url" default:"http://www.qualcomm.cn/generate_204"`
// We use TcpCheckUrl to check (tcp)*(ipv4/ipv6) connectivity for non-direct and non-DNS packets.
TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com"`
// We use UdpCheckDns to check (tcp/udp)*(ipv4/ipv6) connectivity for DNS packets,
// and udp*(ipv4/ipv6) connectivity for all other types of packets.
UdpCheckDns string `mapstructure:"udp_check_dns" default:"dns.google:53"`
CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"`
CheckTolerance time.Duration `mapstructure:"check_tolerance" default:"0"`
DnsUpstream common.UrlOrEmpty `mapstructure:"dns_upstream" required:""`

View File

@ -8,6 +8,7 @@ package control
import (
"github.com/cilium/ebpf"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/component/outbound/dialer"
"golang.org/x/sys/unix"
"strconv"
)
@ -22,11 +23,11 @@ func FormatL4Proto(l4proto uint8) string {
return strconv.Itoa(int(l4proto))
}
func (c *ControlPlaneCore) OutboundAliveChangeCallback(outbound uint8) func(alive bool, l4proto uint8, ipversion uint8) {
return func(alive bool, l4proto uint8, ipversion uint8) {
func (c *ControlPlaneCore) OutboundAliveChangeCallback(outbound uint8) func(alive bool, networkType *dialer.NetworkType) {
return func(alive bool, networkType *dialer.NetworkType) {
c.log.WithFields(logrus.Fields{
"alive": alive,
"network": FormatL4Proto(l4proto) + strconv.Itoa(int(ipversion)),
"network": networkType.String(),
"outbound_id": outbound,
}).Warnf("Outbound alive state changed, notify the kernel program.")
@ -36,8 +37,8 @@ func (c *ControlPlaneCore) OutboundAliveChangeCallback(outbound uint8) func(aliv
}
_ = c.bpf.OutboundConnectivityMap.Update(bpfOutboundConnectivityQuery{
Outbound: outbound,
L4proto: l4proto,
Ipversion: ipversion,
L4proto: networkType.L4Proto.ToL4Proto(),
Ipversion: networkType.IpVersion.ToIpVersion(),
}, value, ebpf.UpdateAny)
}
}

View File

@ -193,12 +193,20 @@ func NewControlPlane(
}
/// DialerGroups (outbounds).
checkDnsTcp := false
if !global.DnsUpstream.Empty {
if scheme, _, _, err := ParseDnsUpstream(global.DnsUpstream.Url); err == nil &&
scheme.ContainsTcp() {
checkDnsTcp = true
}
}
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: global.TcpCheckUrl},
UdpCheckOptionRaw: dialer.UdpCheckOptionRaw{Raw: global.UdpCheckDns},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: global.UdpCheckDns},
CheckInterval: global.CheckInterval,
CheckTolerance: global.CheckTolerance,
CheckDnsTcp: checkDnsTcp,
}
outbounds := []*outbound.DialerGroup{
outbound.NewDialerGroup(option, consts.OutboundDirect.String(),

View File

@ -26,6 +26,16 @@ const (
DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp"
)
func (s DnsUpstreamScheme) ContainsTcp() bool {
switch s {
case DnsUpstreamScheme_TCP,
DnsUpstreamScheme_TCP_UDP:
return true
default:
return false
}
}
type DnsUpstream struct {
Scheme DnsUpstreamScheme
Hostname string
@ -33,16 +43,30 @@ type DnsUpstream struct {
*netutils.Ip46
}
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstream, err error) {
var _port string
switch DnsUpstreamScheme(dnsUpstream.Scheme) {
func ParseDnsUpstream(dnsUpstream *url.URL) (scheme DnsUpstreamScheme, hostname string, port uint16, err error) {
var __port string
switch scheme = DnsUpstreamScheme(dnsUpstream.Scheme); scheme {
case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP:
_port = dnsUpstream.Port()
if _port == "" {
_port = "53"
__port = dnsUpstream.Port()
if __port == "" {
__port = "53"
}
default:
return nil, fmt.Errorf("dns_upstream now only supports auto://, udp://, tcp:// and empty string (as-is)")
return "", "", 0, fmt.Errorf("unexpected dns_upstream format")
}
_port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16)
port = uint16(_port)
if err != nil {
return "", "", 0, fmt.Errorf("parse dns_upstream port: %v", err)
}
hostname = dnsUpstream.Hostname()
return scheme, hostname, port, nil
}
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstream, err error) {
scheme, hostname, port, err := ParseDnsUpstream(dnsUpstream)
if err != nil {
return nil, err
}
systemDns, err := netutils.SystemDns()
@ -55,22 +79,18 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
}
}()
port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16)
if err != nil {
return nil, fmt.Errorf("parse dns_upstream port: %v", err)
}
hostname := dnsUpstream.Hostname()
ip46, err := netutils.ParseIp46(ctx, dialer.SymmetricDirect, systemDns, hostname, false)
ip46, err := netutils.ParseIp46(ctx, dialer.SymmetricDirect, systemDns, hostname, false, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
}
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
return nil, fmt.Errorf("dns_upstream has no record")
}
return &DnsUpstream{
Scheme: DnsUpstreamScheme(dnsUpstream.Scheme),
Scheme: scheme,
Hostname: hostname,
Port: uint16(port),
Port: port,
Ip46: ip46,
}, nil
}

View File

@ -11,6 +11,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
"golang.org/x/sys/unix"
"net"
@ -58,18 +59,21 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
if outboundIndex < 0 || int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound id from bpf is out of range: %v not in [0, %v]", outboundIndex, len(c.outbounds)-1)
}
l4proto := consts.L4ProtoStr_TCP
ipversion := consts.IpVersionFromAddr(dst.Addr())
dialer, _, err := outbound.Select(l4proto, ipversion)
networkType := &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionFromAddr(dst.Addr()),
IsDns: false,
}
d, _, err := outbound.Select(networkType)
if err != nil {
return fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err)
return fmt.Errorf("failed to select dialer from group %v (%v): %w", outbound.Name, networkType.String(), err)
}
c.log.WithFields(logrus.Fields{
"network": string(l4proto) + string(ipversion),
"network": networkType.String(),
"outbound": outbound.Name,
"dialer": dialer.Name(),
"dialer": d.Name(),
}).Infof("%v <-> %v", RefineSourceToShow(src, dst.Addr()), RefineAddrPortToShow(dst))
rConn, err := dialer.Dial("tcp", dst.String())
rConn, err := d.Dial("tcp", dst.String())
if err != nil {
return fmt.Errorf("failed to dial %v: %w", dst, err)
}

View File

@ -198,19 +198,24 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
c.log.WithFields(logrus.Fields{
"ipversions": ipversions,
"l4protos": l4protos,
"src": src.String(),
}).Traceln("Choose DNS path")
// Get the min latency path.
networkType := dialer.NetworkType{
IsDns: isDns,
}
for _, ver := range ipversions {
for _, proto := range l4protos {
d, latency, err := outbound.Select(proto, ver)
networkType.L4Proto = proto
networkType.IpVersion = ver
d, latency, err := outbound.Select(&networkType)
if err != nil {
continue
}
c.log.WithFields(logrus.Fields{
"name": d.Name(),
"latency": latency,
"ver": ver,
"proto": proto,
"network": networkType.String(),
"outbound": outbound.Name,
}).Traceln("Choice")
if bestDialer == nil || latency < bestLatency {
@ -236,10 +241,15 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
"Network": string(l4proto) + string(ipversion),
}).Traceln("Modify DNS target")
}
networkType := &dialer.NetworkType{
L4Proto: l4proto,
IpVersion: ipversion,
IsDns: true,
}
if dialerForNew == nil {
dialerForNew, _, err = outbound.Select(l4proto, ipversion)
dialerForNew, _, err = outbound.Select(networkType)
if err != nil {
return fmt.Errorf("failed to select dialer from group %v: %w", outbound.Name, err)
return fmt.Errorf("failed to select dialer from group %v (%v): %w", outbound.Name, networkType.String(), err)
}
}
@ -271,7 +281,7 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
return fmt.Errorf("failed to GetOrCreate: %w", err)
}
// If the udp endpoint has been not alive, remove it from pool and get a new one.
if !isNew && !ue.Dialer.MustGetAlive(l4proto, ipversion) {
if !isNew && !ue.Dialer.MustGetAlive(networkType) {
c.log.WithFields(logrus.Fields{
"src": RefineSourceToShow(src, dst.Addr()),
"network": string(l4proto) + string(ipversion),
@ -305,27 +315,33 @@ func (c *ControlPlane) handlePkt(data []byte, src, dst netip.AddrPort, outboundI
_ = conn.SetDeadline(time.Now().Add(natTimeout))
// We should write two byte length in the front of TCP DNS request.
bLen := pool.Get(2)
defer pool.Put(bLen)
binary.BigEndian.PutUint16(bLen, uint16(len(data)))
_, err = conn.Write(bLen)
bReq := pool.Get(2 + len(data))
defer pool.Put(bReq)
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
copy(bReq[2:], data)
_, err = conn.Write(bReq)
if err != nil {
return fmt.Errorf("failed to write DNS req length: %w", err)
}
if _, err = conn.Write(data); err != nil {
return fmt.Errorf("failed to write DNS req payload: %w", err)
return fmt.Errorf("failed to write DNS req: %w", err)
}
// Read two byte length.
if _, err = io.ReadFull(conn, bLen); err != nil {
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
}
buf := pool.Get(int(binary.BigEndian.Uint16(bLen)))
defer pool.Put(buf)
if _, err = io.ReadFull(conn, buf); err != nil {
respLen := int(binary.BigEndian.Uint16(bReq))
// Try to reuse the buf.
var buf []byte
if len(bReq) < respLen {
buf = pool.Get(respLen)
defer pool.Put(buf)
} else {
buf = bReq
}
var n int
if n, err = conn.Read(buf[:respLen]); err != nil {
return fmt.Errorf("failed to read DNS resp payload: %w", err)
}
if err = udpHandler(buf, destToSend); err != nil {
if err = udpHandler(buf[:n], destToSend); err != nil {
return fmt.Errorf("failed to write DNS resp to client: %w", err)
}
}

View File

@ -6,10 +6,12 @@ global {
log_level: info
# Node connectivity check.
# URL and DNS should have both IPv4 and IPv6 if you want to check both.
# Host of URL should have both IPv4 and IPv6 if you have double stack in local.
tcp_check_url: 'http://cp.cloudflare.com'
# This DNS will be used to check UDP connectivity of nodes.
# This DNS will be used to check UDP connectivity. And if dns_upstream below contains tcp, it also be used to check
# TCP DNS connectivity of nodes.
# Host of DNS should have both IPv4 and IPv6 if you have double stack in local.
udp_check_dns: 'dns.google:53'
check_interval: 30s