fix: should allow fallbacking ip version if dialing domain (#164)

This commit is contained in:
mzz 2023-06-29 22:30:33 +08:00 committed by GitHub
parent 1a9afb0913
commit 0bac1c6ecb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 56 additions and 25 deletions

View File

@ -131,6 +131,16 @@ const (
IpVersion_X IpVersionType = 3 IpVersion_X IpVersionType = 3
) )
func (v IpVersionType) ToIpVersionStr() IpVersionStr {
switch v {
case IpVersion_4:
return IpVersionStr_4
case IpVersion_6:
return IpVersionStr_6
}
panic("unsupported ipversion")
}
var ( var (
BasicFeatureVersion = internal.Version{5, 2, 0} BasicFeatureVersion = internal.Version{5, 2, 0}
// Deprecated: Ftrace does not support arm64 yet (Linux 6.2). // Deprecated: Ftrace does not support arm64 yet (Linux 6.2).

View File

@ -6,6 +6,7 @@
package outbound package outbound
import ( import (
"errors"
"fmt" "fmt"
"time" "time"
@ -207,8 +208,17 @@ func (d *DialerGroup) MustGetAliveDialerSet(typ *dialer.NetworkType) *dialer.Ali
panic("invalid param") panic("invalid param")
} }
// Select selects a dialer from group according to selectionPolicy. // Select selects a dialer from group according to selectionPolicy. If 'strictIpVersion' is false and no alive dialer, it will fallback to another ipversion.
func (g *DialerGroup) Select(networkType *dialer.NetworkType) (d *dialer.Dialer, latency time.Duration, err error) { func (g *DialerGroup) Select(networkType *dialer.NetworkType, strictIpVersion bool) (d *dialer.Dialer, latency time.Duration, err error) {
d, latency, err = g._select(networkType)
if !strictIpVersion && errors.Is(err, NoAliveDialerError) {
networkType.IpVersion = (consts.IpVersion_X - networkType.IpVersion.ToIpVersionType()).ToIpVersionStr()
return g._select(networkType)
}
return d, latency, err
}
func (g *DialerGroup) _select(networkType *dialer.NetworkType) (d *dialer.Dialer, latency time.Duration, err error) {
if len(g.Dialers) == 0 { if len(g.Dialers) == 0 {
return nil, 0, fmt.Errorf("no dialer in this group") return nil, 0, fmt.Errorf("no dialer in this group")
} }

View File

@ -24,13 +24,16 @@ func FormatL4Proto(l4proto uint8) string {
return strconv.Itoa(int(l4proto)) return strconv.Itoa(int(l4proto))
} }
func (c *controlPlaneCore) OutboundAliveChangeCallback(outbound uint8) func(alive bool, networkType *dialer.NetworkType, isInit bool) { func (c *controlPlaneCore) outboundAliveChangeCallback(outbound uint8, dryrun bool) func(alive bool, networkType *dialer.NetworkType, isInit bool) {
return func(alive bool, networkType *dialer.NetworkType, isInit bool) { return func(alive bool, networkType *dialer.NetworkType, isInit bool) {
select { select {
case <-c.closed.Done(): case <-c.closed.Done():
return return
default: default:
} }
if !isInit && dryrun {
return
}
if !isInit || c.log.IsLevelEnabled(logrus.TraceLevel) { if !isInit || c.log.IsLevelEnabled(logrus.TraceLevel) {
strAlive := "NOT ALIVE" strAlive := "NOT ALIVE"
if alive { if alive {

View File

@ -248,19 +248,30 @@ func NewControlPlane(
TlsImplementation: global.TlsImplementation, TlsImplementation: global.TlsImplementation,
UtlsImitate: global.UtlsImitate, UtlsImitate: global.UtlsImitate,
} }
// Dial mode.
dialMode, err := consts.ParseDialMode(global.DialMode)
if err != nil {
return nil, err
}
sniffingTimeout := global.SniffingTimeout
if dialMode == consts.DialMode_Ip {
sniffingTimeout = 0
}
disableKernelAliveCallback := dialMode != consts.DialMode_Ip
outbounds := []*outbound.DialerGroup{ outbounds := []*outbound.DialerGroup{
outbound.NewDialerGroup(option, consts.OutboundDirect.String(), outbound.NewDialerGroup(option, consts.OutboundDirect.String(),
[]*dialer.Dialer{dialer.NewDirectDialer(option, true)}, []*dialer.Dialer{dialer.NewDirectDialer(option, true)},
outbound.DialerSelectionPolicy{ outbound.DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed, Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: 0, FixedIndex: 0,
}, core.OutboundAliveChangeCallback(0)), }, core.outboundAliveChangeCallback(0, disableKernelAliveCallback)),
outbound.NewDialerGroup(option, consts.OutboundBlock.String(), outbound.NewDialerGroup(option, consts.OutboundBlock.String(),
[]*dialer.Dialer{dialer.NewBlockDialer(option, func() { /*Dialer Outbound*/ })}, []*dialer.Dialer{dialer.NewBlockDialer(option, func() { /*Dialer Outbound*/ })},
outbound.DialerSelectionPolicy{ outbound.DialerSelectionPolicy{
Policy: consts.DialerSelectionPolicy_Fixed, Policy: consts.DialerSelectionPolicy_Fixed,
FixedIndex: 0, FixedIndex: 0,
}, core.OutboundAliveChangeCallback(1)), }, core.outboundAliveChangeCallback(1, disableKernelAliveCallback)),
} }
// Filter out groups. // Filter out groups.
@ -290,7 +301,8 @@ func NewControlPlane(
log.Infoln("\t<Empty>") log.Infoln("\t<Empty>")
} }
// Create dialer group and append it to outbounds. // Create dialer group and append it to outbounds.
dialerGroup := outbound.NewDialerGroup(option, group.Name, dialers, *policy, core.OutboundAliveChangeCallback(uint8(len(outbounds)))) dialerGroup := outbound.NewDialerGroup(option, group.Name, dialers, *policy,
core.outboundAliveChangeCallback(uint8(len(outbounds)), disableKernelAliveCallback))
outbounds = append(outbounds, dialerGroup) outbounds = append(outbounds, dialerGroup)
} }
@ -339,16 +351,7 @@ func NewControlPlane(
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildUserspace: %w", err) return nil, fmt.Errorf("RoutingMatcherBuilder.BuildUserspace: %w", err)
} }
/// Dial mode. // New control plane.
dialMode, err := consts.ParseDialMode(global.DialMode)
if err != nil {
return nil, err
}
sniffingTimeout := global.SniffingTimeout
if dialMode == consts.DialMode_Ip {
sniffingTimeout = 0
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
plane := &ControlPlane{ plane := &ControlPlane{
log: log, log: log,
@ -553,7 +556,7 @@ func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err
return nil return nil
} }
func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip.AddrPort, domain string) (dialTarget string, shouldReroute bool) { func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip.AddrPort, domain string) (dialTarget string, shouldReroute bool, dialIp bool) {
dialMode := consts.DialMode_Ip dialMode := consts.DialMode_Ip
if !outbound.IsReserved() && domain != "" { if !outbound.IsReserved() && domain != "" {
@ -601,6 +604,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
switch dialMode { switch dialMode {
case consts.DialMode_Ip: case consts.DialMode_Ip:
dialTarget = dst.String() dialTarget = dst.String()
dialIp = true
case consts.DialMode_Domain: case consts.DialMode_Domain:
if strings.HasPrefix(domain, "[") && strings.HasSuffix(domain, "]") { if strings.HasPrefix(domain, "[") && strings.HasSuffix(domain, "]") {
// Sniffed domain may be like `[2606:4700:20::681a:d1f]`. We should remove the brackets. // Sniffed domain may be like `[2606:4700:20::681a:d1f]`. We should remove the brackets.
@ -609,6 +613,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
if _, err := netip.ParseAddr(domain); err == nil { if _, err := netip.ParseAddr(domain); err == nil {
// domain is IPv4 or IPv6 (has colon) // domain is IPv4 or IPv6 (has colon)
dialTarget = net.JoinHostPort(domain, strconv.Itoa(int(dst.Port()))) dialTarget = net.JoinHostPort(domain, strconv.Itoa(int(dst.Port())))
dialIp = true
} else if _, _, err := net.SplitHostPort(domain); err == nil { } else if _, _, err := net.SplitHostPort(domain); err == nil {
// domain is already domain:port // domain is already domain:port
@ -622,7 +627,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
"to": dialTarget, "to": dialTarget,
}).Debugln("Rewrite dial target to domain") }).Debugln("Rewrite dial target to domain")
} }
return dialTarget, shouldReroute return dialTarget, shouldReroute, dialIp
} }
type Listener struct { type Listener struct {
@ -852,7 +857,8 @@ func (c *ControlPlane) chooseBestDnsDialer(
return nil, fmt.Errorf("bad outbound index: %v", outboundIndex) return nil, fmt.Errorf("bad outbound index: %v", outboundIndex)
} }
dialerGroup := c.outbounds[outboundIndex] dialerGroup := c.outbounds[outboundIndex]
d, latency, err := dialerGroup.Select(&networkType) // DNS always dial IP.
d, latency, err := dialerGroup.Select(&networkType, true)
if err != nil { if err != nil {
continue continue
} }

View File

@ -82,7 +82,7 @@ destRetrieved:
outboundIndex = consts.OutboundControlPlaneRouting outboundIndex = consts.OutboundControlPlaneRouting
} }
dialTarget, shouldReroute := c.ChooseDialTarget(outboundIndex, dst, domain) dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, dst, domain)
if shouldReroute { if shouldReroute {
outboundIndex = consts.OutboundControlPlaneRouting outboundIndex = consts.OutboundControlPlaneRouting
} }
@ -102,7 +102,7 @@ destRetrieved:
) )
} }
// Reset dialTarget. // Reset dialTarget.
dialTarget, _ = c.ChooseDialTarget(outboundIndex, dst, domain) dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, dst, domain)
default: default:
} }
if routingResult.Mark == 0 { if routingResult.Mark == 0 {
@ -118,7 +118,8 @@ destRetrieved:
IpVersion: consts.IpVersionFromAddr(dst.Addr()), IpVersion: consts.IpVersionFromAddr(dst.Addr()),
IsDns: false, IsDns: false,
} }
d, _, err := outbound.Select(networkType) strictIpVersion := dialIp
d, _, err := outbound.Select(networkType, strictIpVersion)
if err != nil { if err != nil {
return fmt.Errorf("failed to select dialer from group %v (%v): %w", outbound.Name, networkType.String(), err) return fmt.Errorf("failed to select dialer from group %v (%v): %w", outbound.Name, networkType.String(), err)
} }

View File

@ -146,7 +146,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
outboundIndex = consts.OutboundControlPlaneRouting outboundIndex = consts.OutboundControlPlaneRouting
} }
dialTarget, shouldReroute := c.ChooseDialTarget(outboundIndex, realDst, domain) dialTarget, shouldReroute, dialIp := c.ChooseDialTarget(outboundIndex, realDst, domain)
if shouldReroute { if shouldReroute {
outboundIndex = consts.OutboundControlPlaneRouting outboundIndex = consts.OutboundControlPlaneRouting
} }
@ -173,7 +173,7 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
) )
} }
// Reset dialTarget. // Reset dialTarget.
dialTarget, _ = c.ChooseDialTarget(outboundIndex, realDst, domain) dialTarget, _, dialIp = c.ChooseDialTarget(outboundIndex, realDst, domain)
default: default:
} }
if routingResult.Mark == 0 { if routingResult.Mark == 0 {
@ -201,7 +201,8 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
IpVersion: consts.IpVersionFromAddr(realDst.Addr()), IpVersion: consts.IpVersionFromAddr(realDst.Addr()),
IsDns: true, // UDP relies on DNS check result. IsDns: true, // UDP relies on DNS check result.
} }
dialerForNew, _, err := outbound.Select(networkType) strictIpVersion := dialIp
dialerForNew, _, err := outbound.Select(networkType, strictIpVersion)
if err != nil { if err != nil {
return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err) return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
} }