refactor: rename check_url to tcp_check_url and restrict dns_upstream as udp://ip:port

This commit is contained in:
mzz2017
2023-02-08 16:07:23 +08:00
committed by mzz
parent a3d4a06dab
commit 551e79d9e5
6 changed files with 88 additions and 34 deletions

View File

@ -75,11 +75,7 @@ func Run(log *logrus.Logger, param *config.Params) (err error) {
nodeList, nodeList,
param.Group, param.Group,
&param.Routing, &param.Routing,
param.Global.DnsUpstream, &param.Global,
param.Global.CheckUrl,
param.Global.CheckInterval,
param.Global.LanInterface,
param.Global.WanInterface,
) )
if err != nil { if err != nil {
return err return err

View File

@ -10,6 +10,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"github.com/v2rayA/dae/config"
"net/url" "net/url"
"reflect" "reflect"
"strconv" "strconv"
@ -276,6 +277,27 @@ func FuzzyDecode(to interface{}, val string) bool {
} }
case reflect.String: case reflect.String:
v.SetString(val) v.SetString(val)
case reflect.Struct:
switch v.Interface().(type) {
case config.UrlOrEmpty:
if val == "" {
v.Set(reflect.ValueOf(config.UrlOrEmpty{
Url: nil,
Empty: true,
}))
} else {
u, err := url.Parse(val)
if err != nil {
return false
}
v.Set(reflect.ValueOf(config.UrlOrEmpty{
Url: u,
Empty: false,
}))
}
default:
return false
}
default: default:
return false return false
} }

View File

@ -8,16 +8,23 @@ package config
import ( import (
"fmt" "fmt"
"github.com/v2rayA/dae/pkg/config_parser" "github.com/v2rayA/dae/pkg/config_parser"
"net/url"
"reflect" "reflect"
"time" "time"
) )
type UrlOrEmpty struct {
Url *url.URL
Empty bool
}
type Global struct { type Global struct {
TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"` TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"`
LogLevel string `mapstructure:"log_level" default:"info"` LogLevel string `mapstructure:"log_level" default:"info"`
CheckUrl string `mapstructure:"check_url" default:"https://connectivitycheck.gstatic.com/generate_204"` TcpCheckUrl string `mapstructure:"tcp_check_url" default:"https://connectivitycheck.gstatic.com/generate_204"`
UdpCheckDns string `mapstructure:"udp_check_dns" default:"8.8.8.8:53"`
CheckInterval time.Duration `mapstructure:"check_interval" default:"15s"` CheckInterval time.Duration `mapstructure:"check_interval" default:"15s"`
DnsUpstream string `mapstructure:"dns_upstream" require:""` DnsUpstream UrlOrEmpty `mapstructure:"dns_upstream" require:""`
LanInterface []string `mapstructure:"lan_interface"` LanInterface []string `mapstructure:"lan_interface"`
WanInterface []string `mapstructure:"wan_interface"` WanInterface []string `mapstructure:"wan_interface"`
} }
@ -81,13 +88,13 @@ func New(sections []*config_parser.Section) (params *Params, err error) {
if !ok { if !ok {
return nil, fmt.Errorf("no parser is specified in field %v", structField.Name) return nil, fmt.Errorf("no parser is specified in field %v", structField.Name)
} }
parserFunc, ok := ParserMap[parserName] parser, ok := ParserMap[parserName]
if !ok { if !ok {
return nil, fmt.Errorf("unknown parser %v in field %v", parserName, structField.Name) return nil, fmt.Errorf("unknown parser %v in field %v", parserName, structField.Name)
} }
// Parse section and unmarshal to field. // Parse section and unmarshal to field.
if err := parserFunc(field.Addr(), section.Val); err != nil { if err := parser(field.Addr(), section.Val); err != nil {
return nil, fmt.Errorf("failed to parse \"%v\": %w", sectionName, err) return nil, fmt.Errorf("failed to parse \"%v\": %w", sectionName, err)
} }
section.Parsed = true section.Parsed = true

View File

@ -29,7 +29,6 @@ import (
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
"time"
) )
type ControlPlane struct { type ControlPlane struct {
@ -56,11 +55,7 @@ func NewControlPlane(
nodes []string, nodes []string,
groups []config.Group, groups []config.Group,
routingA *config.Routing, routingA *config.Routing,
dnsUpstream string, global *config.Global,
checkUrl string,
checkInterval time.Duration,
lanInterface []string,
wanInterface []string,
) (c *ControlPlane, err error) { ) (c *ControlPlane, err error) {
kernelVersion, e := internal.KernelVersion() kernelVersion, e := internal.KernelVersion()
if e != nil { if e != nil {
@ -73,12 +68,12 @@ func NewControlPlane(
kernelVersion.String(), kernelVersion.String(),
consts.ChecksumFeatureVersion.String()) consts.ChecksumFeatureVersion.String())
} }
if len(wanInterface) > 0 && kernelVersion.Less(consts.CgSocketCookieFeatureVersion) { if len(global.WanInterface) > 0 && kernelVersion.Less(consts.CgSocketCookieFeatureVersion) {
return nil, fmt.Errorf("your kernel version %v does not support bind to WAN; expect >=%v; remove wan_interface in config file and try again", return nil, fmt.Errorf("your kernel version %v does not support bind to WAN; expect >=%v; remove wan_interface in config file and try again",
kernelVersion.String(), kernelVersion.String(),
consts.CgSocketCookieFeatureVersion.String()) consts.CgSocketCookieFeatureVersion.String())
} }
if len(lanInterface) > 0 && kernelVersion.Less(consts.SkAssignFeatureVersion) { if len(global.LanInterface) > 0 && kernelVersion.Less(consts.SkAssignFeatureVersion) {
return nil, fmt.Errorf("your kernel version %v does not support bind to LAN; expect >=%v; remove lan_interface in config file and try again", return nil, fmt.Errorf("your kernel version %v does not support bind to LAN; expect >=%v; remove lan_interface in config file and try again",
kernelVersion.String(), kernelVersion.String(),
consts.SkAssignFeatureVersion.String()) consts.SkAssignFeatureVersion.String())
@ -117,8 +112,8 @@ func NewControlPlane(
if err = selectivelyLoadBpfObjects(log, &bpf, &loadBpfOptions{ if err = selectivelyLoadBpfObjects(log, &bpf, &loadBpfOptions{
PinPath: pinPath, PinPath: pinPath,
CollectionOptions: collectionOpts, CollectionOptions: collectionOpts,
BindLan: len(lanInterface) > 0, BindLan: len(global.LanInterface) > 0,
BindWan: len(wanInterface) > 0, BindWan: len(global.WanInterface) > 0,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("load eBPF objects: %w", err) return nil, fmt.Errorf("load eBPF objects: %w", err)
} }
@ -164,26 +159,26 @@ func NewControlPlane(
}() }()
/// Bind to links. Binding should be advance of dialerGroups to avoid un-routable old connection. /// Bind to links. Binding should be advance of dialerGroups to avoid un-routable old connection.
// Add clsact qdisc // Add clsact qdisc
for _, ifname := range common.Deduplicate(append(append([]string{}, lanInterface...), wanInterface...)) { for _, ifname := range common.Deduplicate(append(append([]string{}, global.LanInterface...), global.WanInterface...)) {
_ = core.addQdisc(ifname) _ = core.addQdisc(ifname)
} }
// Bind to LAN // Bind to LAN
if len(lanInterface) > 0 { if len(global.LanInterface) > 0 {
if err = core.setupRoutingPolicy(); err != nil { if err = core.setupRoutingPolicy(); err != nil {
return nil, err return nil, err
} }
for _, ifname := range lanInterface { for _, ifname := range global.LanInterface {
if err = core.bindLan(ifname); err != nil { if err = core.bindLan(ifname); err != nil {
return nil, fmt.Errorf("bindLan: %v: %w", ifname, err) return nil, fmt.Errorf("bindLan: %v: %w", ifname, err)
} }
} }
} }
// Bind to WAN // Bind to WAN
if len(wanInterface) > 0 { if len(global.WanInterface) > 0 {
if err = core.setupSkPidMonitor(); err != nil { if err = core.setupSkPidMonitor(); err != nil {
return nil, err return nil, err
} }
for _, ifname := range wanInterface { for _, ifname := range global.WanInterface {
if err = core.bindWan(ifname); err != nil { if err = core.bindWan(ifname); err != nil {
return nil, fmt.Errorf("bindWan: %v: %w", ifname, err) return nil, fmt.Errorf("bindWan: %v: %w", ifname, err)
} }
@ -193,8 +188,8 @@ func NewControlPlane(
/// DialerGroups (outbounds). /// DialerGroups (outbounds).
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
CheckUrl: checkUrl, CheckUrl: global.TcpCheckUrl,
CheckInterval: checkInterval, CheckInterval: global.CheckInterval,
} }
outbounds := []*outbound.DialerGroup{ outbounds := []*outbound.DialerGroup{
outbound.NewDialerGroup(option, consts.OutboundDirect.String(), outbound.NewDialerGroup(option, consts.OutboundDirect.String(),
@ -273,10 +268,9 @@ func NewControlPlane(
/// DNS upstream. /// DNS upstream.
var dnsAddrPort netip.AddrPort var dnsAddrPort netip.AddrPort
if dnsUpstream != "" { if !global.DnsUpstream.Empty {
dnsAddrPort, err = netip.ParseAddrPort(dnsUpstream) if dnsAddrPort, err = resolveDnsUpstream(global.DnsUpstream.Url); err != nil {
if err != nil { return nil, err
return nil, fmt.Errorf("failed to parse DNS upstream: \"%v\": %w", dnsUpstream, err)
} }
dnsAddr16 := dnsAddrPort.Addr().As16() dnsAddr16 := dnsAddrPort.Addr().As16()
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{ if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{
@ -286,6 +280,7 @@ func NewControlPlane(
return nil, err return nil, err
} }
} else { } else {
// Empty.
if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{ if err = bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{
Ip: [4]uint32{}, Ip: [4]uint32{},
// Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array. // Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array.
@ -297,7 +292,7 @@ func NewControlPlane(
/// Listen address. /// Listen address.
listenIp := "::1" listenIp := "::1"
if len(wanInterface) > 0 { if len(global.WanInterface) > 0 {
listenIp = "0.0.0.0" listenIp = "0.0.0.0"
} }
return &ControlPlane{ return &ControlPlane{

34
control/control_utils.go Normal file
View File

@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package control
import (
"fmt"
"net"
"net/netip"
"net/url"
)
func resolveDnsUpstream(dnsUpstream *url.URL) (addrPort netip.AddrPort, err error) {
if dnsUpstream.Scheme != "udp" {
return netip.AddrPort{}, fmt.Errorf("dns_upstream now only supports udp://")
}
port := dnsUpstream.Port()
if port == "" {
port = "53"
}
hostname := dnsUpstream.Hostname()
ips, _ := net.LookupIP(hostname)
if len(ips) == 0 {
return netip.AddrPort{}, fmt.Errorf("cannot resolve hostname of dns upstream: %v", hostname)
}
// resolve hostname
dnsAddrPort, err := netip.ParseAddrPort(net.JoinHostPort(ips[0].String(), port))
if err != nil {
return netip.AddrPort{}, fmt.Errorf("failed to parse DNS upstream: \"%v\": %w", dnsUpstream.String(), err)
}
return dnsAddrPort, nil
}

View File

@ -6,14 +6,14 @@ global {
log_level: info log_level: info
# Node connectivity check. # Node connectivity check.
check_url: 'https://connectivitycheck.gstatic.com/generate_204' tcp_check_url: 'https://connectivitycheck.gstatic.com/generate_204'
check_interval: 30s check_interval: 30s
# Now only support UDP and format IP:Port. Empty value '' indicates as-is. # Now only support udp://IP:Port. Empty value '' indicates as-is.
# Please make sure DNS traffic will go through and be forwarded by dae. # Please make sure DNS traffic will go through and be forwarded by dae.
# The upstream DNS answer MUST NOT be polluted. # The upstream DNS answer MUST NOT be polluted.
# The request to dns upstream follows routing defined below. # The request to dns upstream follows routing defined below.
dns_upstream: '8.8.8.8:53' dns_upstream: 'udp://8.8.8.8:53'
# The LAN interface to bind. Use it if you only want to proxy LAN instead of localhost. # The LAN interface to bind. Use it if you only want to proxy LAN instead of localhost.
# Multiple interfaces split by ",". # Multiple interfaces split by ",".