feat: support tcp:// and tcp+udp:// for dns_upstream (#11)

This commit is contained in:
mzz
2023-02-09 11:40:34 +08:00
committed by GitHub
parent ac8b88d8ca
commit 15faa3cdd2
15 changed files with 697 additions and 185 deletions

View File

@ -83,8 +83,8 @@ func (a *AliveDialerSet) GetRand() *Dialer {
}
// GetMinLatency acquires correct selectionPolicy.
func (a *AliveDialerSet) GetMinLatency() *Dialer {
return a.minLatency.dialer
func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) {
return a.minLatency.dialer, a.minLatency.latency
}
// NotifyLatencyChange should be invoked when dialer every time latency and alive state changes.

View File

@ -24,10 +24,6 @@ import (
"time"
)
var (
BootstrapDns = netip.MustParseAddrPort("223.5.5.5:53")
)
type collection struct {
// AliveDialerSetSet uses reference counting.
AliveDialerSetSet AliveDialerSetSet
@ -71,43 +67,22 @@ func (d *Dialer) MustGetAlive(l4proto consts.L4ProtoStr, ipversion consts.IpVers
return d.mustGetCollection(l4proto, ipversion).Alive
}
type Ip46 struct {
Ip4 netip.Addr
Ip6 netip.Addr
}
func ParseIp46(ctx context.Context, host string) (ipv46 *Ip46, err error) {
addrs4, err := netutils.ResolveNetip(ctx, SymmetricDirect, BootstrapDns, host, dnsmessage.TypeA)
if err != nil {
return nil, err
}
if len(addrs4) == 0 {
return nil, fmt.Errorf("domain \"%v\" has no ipv4 record", host)
}
addrs6, err := netutils.ResolveNetip(ctx, SymmetricDirect, BootstrapDns, host, dnsmessage.TypeAAAA)
if err != nil {
return nil, err
}
if len(addrs6) == 0 {
return nil, fmt.Errorf("domain \"%v\" has no ipv6 record", host)
}
return &Ip46{
Ip4: addrs4[0],
Ip6: addrs6[0],
}, nil
}
type TcpCheckOption struct {
Url *netutils.URL
*Ip46
*netutils.Ip46
}
func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOption, err error) {
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
}
u, err := url.Parse(rawURL)
if err != nil {
return nil, err
}
ip46, err := ParseIp46(ctx, u.Hostname())
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, u.Hostname(), true)
if err != nil {
return nil, err
}
@ -120,10 +95,15 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
type UdpCheckOption struct {
DnsHost string
DnsPort uint16
*Ip46
*netutils.Ip46
}
func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheckOption, err error) {
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
}
host, _port, err := net.SplitHostPort(dnsHostPort)
if err != nil {
return nil, err
@ -132,7 +112,7 @@ func ParseUdpCheckOption(ctx context.Context, dnsHostPort string) (opt *UdpCheck
if err != nil {
return nil, fmt.Errorf("bad port: %v", err)
}
ip46, err := ParseIp46(ctx, host)
ip46, err := netutils.ParseIp46(ctx, SymmetricDirect, systemDns, host, true)
if err != nil {
return nil, err
}

View File

@ -15,6 +15,7 @@ import (
"net"
"net/netip"
"strings"
"time"
)
type DialerGroup struct {
@ -95,9 +96,9 @@ func (g *DialerGroup) SetSelectionPolicy(policy DialerSelectionPolicy) {
}
// Select selects a dialer from group according to selectionPolicy.
func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) (*dialer.Dialer, error) {
func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr) (d *dialer.Dialer, latency time.Duration, err error) {
if len(g.Dialers) == 0 {
return nil, fmt.Errorf("no dialer in this group")
return nil, 0, fmt.Errorf("no dialer in this group")
}
var a *dialer.AliveDialerSet
switch l4proto {
@ -116,7 +117,7 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi
a = g.AliveUdp6DialerSet
}
default:
return nil, fmt.Errorf("DialerGroup.Select: unexpected l4proto type: %v", l4proto)
return nil, 0, fmt.Errorf("DialerGroup.Select: unexpected l4proto type: %v", l4proto)
}
switch g.selectionPolicy.Policy {
@ -128,30 +129,30 @@ func (g *DialerGroup) Select(l4proto consts.L4ProtoStr, ipversion consts.IpVersi
"network": string(l4proto) + string(ipversion),
"group": g.Name,
}).Warnf("No alive dialer in DialerGroup, use \"block\".")
return g.block, nil
return g.block, 0, nil
}
return d, nil
return d, 0, nil
case consts.DialerSelectionPolicy_Fixed:
if g.selectionPolicy.FixedIndex < 0 || g.selectionPolicy.FixedIndex >= len(g.Dialers) {
return nil, fmt.Errorf("selected dialer index is out of range")
return nil, 0, fmt.Errorf("selected dialer index is out of range")
}
return g.Dialers[g.selectionPolicy.FixedIndex], nil
return g.Dialers[g.selectionPolicy.FixedIndex], 0, nil
case consts.DialerSelectionPolicy_MinLastLatency, consts.DialerSelectionPolicy_MinAverage10Latencies:
d := a.GetMinLatency()
d, latency := a.GetMinLatency()
if d == nil {
// No alive dialer.
g.log.WithFields(logrus.Fields{
"network": string(l4proto) + string(ipversion),
"group": g.Name,
}).Warnf("No alive dialer in DialerGroup, use \"block\".")
return g.block, nil
return g.block, 0, nil
}
return d, nil
return d, latency, nil
default:
return nil, fmt.Errorf("unsupported DialerSelectionPolicy: %v", g.selectionPolicy)
return nil, 0, fmt.Errorf("unsupported DialerSelectionPolicy: %v", g.selectionPolicy)
}
}
@ -164,9 +165,9 @@ func (g *DialerGroup) Dial(network string, addr string) (c net.Conn, err error)
ipversion := consts.IpVersionFromAddr(ipAddr)
switch {
case strings.HasPrefix(network, "tcp"):
d, err = g.Select(consts.L4ProtoStr_TCP, ipversion)
d, _, err = g.Select(consts.L4ProtoStr_TCP, ipversion)
case strings.HasPrefix(network, "udp"):
d, err = g.Select(consts.L4ProtoStr_UDP, ipversion)
d, _, err = g.Select(consts.L4ProtoStr_UDP, ipversion)
default:
return nil, fmt.Errorf("unexpected network: %v", network)
}

View File

@ -44,9 +44,9 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: fixedIndex,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
for i := 0; i < 10; i++ {
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
fixedIndex = 0
g.selectionPolicy.FixedIndex = fixedIndex
for i := 0; i < 10; i++ {
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -98,7 +98,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_MinLastLatency,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
// Test 1000 times.
for i := 0; i < 1000; i++ {
@ -127,7 +127,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
}
g.AliveTcp4DialerSet.NotifyLatencyChange(d, alive)
}
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -170,10 +170,10 @@ func TestDialerGroup_Select_Random(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_TCP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}
@ -217,12 +217,12 @@ func TestDialerGroup_SetAlive(t *testing.T) {
}
g := NewDialerGroup(option, "test-group", dialers, DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Random,
})
}, func(alive bool, l4proto uint8, ipversion uint8) {})
zeroTarget := 3
g.AliveTcp4DialerSet.NotifyLatencyChange(dialers[zeroTarget], false)
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, err := g.Select(consts.L4ProtoStr_UDP, consts.IpVersionStr_4)
d, _, err := g.Select(consts.L4ProtoStr_UDP, consts.IpVersionStr_4)
if err != nil {
t.Fatal(err)
}