diff --git a/README.md b/README.md index 25e1366..82990b8 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,5 @@ See [example.dae](https://github.com/v2rayA/dae/blob/main/example.dae). But the problem is, after the Linux network stack, before entering the network card, we modify the source IP of this packet, causing the Linux network stack to only make a simple checksum, and the NIC also assumes that this packet is not sent from local, so no further checksum completing. 1. MACv2 extension extraction. 1. Log to userspace. -1. Support include section. -1. Subscription section supports "file://" -1. Subscription section supports key. +1. Subscription section supports key. And support to filter by subscription key. 1. ... diff --git a/cmd/internal/subscription.go b/cmd/internal/subscription.go index 721067e..35c08e0 100644 --- a/cmd/internal/subscription.go +++ b/cmd/internal/subscription.go @@ -1,6 +1,8 @@ package internal import ( + "bufio" + "bytes" "encoding/json" "fmt" "github.com/sirupsen/logrus" @@ -9,6 +11,8 @@ import ( "net" "net/http" "net/url" + "os" + "path/filepath" "strconv" "strings" ) @@ -80,26 +84,82 @@ func resolveSubscriptionAsSIP008(log *logrus.Logger, b []byte) (nodes []string, return nodes, nil } -func ResolveSubscription(log *logrus.Logger, subscription string) (nodes []string, err error) { +func resolveFile(u *url.URL, configDir string) (b []byte, err error) { + if u.Host == "" { + return nil, fmt.Errorf("not support absolute path") + } + /// Relative location. + // Make sure path is secure. + path := filepath.Join(configDir, u.Host, u.Path) + if err = common.EnsureFileInSubDir(path, configDir); err != nil { + return nil, err + } + /// Read and resolve + f, err := os.Open(path) + if err != nil { + return nil, err + } + // Check file access. + fi, err := f.Stat() + if err != nil { + return nil, err + } + if fi.IsDir() { + return nil, fmt.Errorf("subscription file cannot be a directory: %v", path) + } + if fi.Mode()&0037 > 0 { + return nil, fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, path) + } + // Resolve the first line instruction. + fReader := bufio.NewReader(f) + b, err = fReader.Peek(1) + if err != nil { + return nil, err + } + if string(b[0]) == "@" { + // Instruction line. But not support yet. + _, _, err = fReader.ReadLine() + if err != nil { + return nil, err + } + } + + b, err = io.ReadAll(fReader) + if err != nil { + return nil, err + } + return bytes.TrimSpace(b), err +} + +func ResolveSubscription(log *logrus.Logger, configDir string, subscription string) (nodes []string, err error) { u, err := url.Parse(subscription) if err != nil { return nil, fmt.Errorf("failed to parse subscription \"%v\": %w", subscription, err) } + log.Debugf("ResolveSubscription: %v", subscription) + var ( + b []byte + resp *http.Response + ) switch u.Scheme { case "file": - // TODO + b, err = resolveFile(u, configDir) + if err != nil { + return nil, err + } + goto resolve default: } - log.Debugf("ResolveSubscription: %v", subscription) - resp, err := http.Get(subscription) + resp, err = http.Get(subscription) if err != nil { return nil, err } defer resp.Body.Close() - b, err := io.ReadAll(resp.Body) + b, err = io.ReadAll(resp.Body) if err != nil { return nil, err } +resolve: if nodes, err = resolveSubscriptionAsSIP008(log, b); err == nil { return nodes, nil } else { diff --git a/cmd/run.go b/cmd/run.go index ea3c917..e2a5976 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -7,10 +7,11 @@ import ( "github.com/v2rayA/dae/cmd/internal" "github.com/v2rayA/dae/config" "github.com/v2rayA/dae/control" - "github.com/v2rayA/dae/pkg/config_parser" "github.com/v2rayA/dae/pkg/logger" "os" "os/signal" + "path/filepath" + "strings" "syscall" ) @@ -30,13 +31,15 @@ var ( internal.AutoSu() // Read config from --config cfgFile. - param, err := readConfig(cfgFile) + param, includes, err := readConfig(cfgFile) if err != nil { logrus.Fatalln("readConfig:", err) } log := logger.NewLogger(param.Global.LogLevel, disableTimestamp) logrus.SetLevel(log.Level) + + log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) if err := Run(log, param); err != nil { logrus.Fatalln(err) } @@ -55,7 +58,7 @@ func Run(log *logrus.Logger, param *config.Params) (err error) { nodeList := make([]string, len(param.Node)) copy(nodeList, param.Node) for _, sub := range param.Subscription { - nodes, err := internal.ResolveSubscription(log, sub) + nodes, err := internal.ResolveSubscription(log, filepath.Dir(cfgFile), sub) if err != nil { log.Warnf(`failed to resolve subscription "%v": %v`, sub, err) } @@ -97,17 +100,14 @@ func Run(log *logrus.Logger, param *config.Params) (err error) { return nil } -func readConfig(cfgFile string) (params *config.Params, err error) { - b, err := os.ReadFile(cfgFile) +func readConfig(cfgFile string) (params *config.Params, entries []string, err error) { + merger := config.NewMerger(cfgFile) + sections, entries, err := merger.Merge() if err != nil { - return nil, err - } - sections, err := config_parser.Parse(string(b)) - if err != nil { - return nil, fmt.Errorf("\n%w", err) + return nil, nil, err } if params, err = config.New(sections); err != nil { - return nil, err + return nil, nil, err } - return params, nil + return params, entries, nil } diff --git a/cmd/validate.go b/cmd/validate.go index 93e929b..ad46e1b 100644 --- a/cmd/validate.go +++ b/cmd/validate.go @@ -21,7 +21,7 @@ var ( os.Exit(1) } // Read config from --config cfgFile. - _, err := readConfig(cfgFile) + _, _, err := readConfig(cfgFile) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/common/utils.go b/common/utils.go index e6791f6..74c1155 100644 --- a/common/utils.go +++ b/common/utils.go @@ -11,6 +11,7 @@ import ( "encoding/hex" "fmt" "net/url" + "path/filepath" "reflect" "strconv" "strings" @@ -307,3 +308,34 @@ func FuzzyDecode(to interface{}, val string) bool { } return true } + +func EnsureFileInSubDir(filePath string, dir string) (err error) { + fileDir := filepath.Dir(filePath) + if len(dir) == 0 { + return fmt.Errorf("bad dir: %v", dir) + } + rel, err := filepath.Rel(dir, fileDir) + if err != nil { + return err + } + if strings.HasPrefix(rel, "..") { + return fmt.Errorf("file is out of scope: %v", rel) + } + return nil +} + +func MapKeys(m interface{}) (keys []string, err error) { + v := reflect.ValueOf(m) + if v.Kind() != reflect.Map { + return nil, fmt.Errorf("MapKeys requires map[string]*") + } + if v.Type().Key().Kind() != reflect.String { + return nil, fmt.Errorf("MapKeys requires map[string]*") + } + _keys := v.MapKeys() + keys = make([]string, 0, len(_keys)) + for _, k := range _keys { + keys = append(keys, k.String()) + } + return keys, nil +} diff --git a/component/outbound/dialer/alive_dialer_set.go b/component/outbound/dialer/alive_dialer_set.go index 1cda121..a98a780 100644 --- a/component/outbound/dialer/alive_dialer_set.go +++ b/component/outbound/dialer/alive_dialer_set.go @@ -26,6 +26,7 @@ type AliveDialerSet struct { dialerGroupName string l4proto consts.L4ProtoStr ipversion consts.IpVersionStr + tolerance time.Duration aliveChangeCallback func(alive bool) @@ -43,6 +44,7 @@ func NewAliveDialerSet( dialerGroupName string, l4proto consts.L4ProtoStr, ipversion consts.IpVersionStr, + tolerance time.Duration, selectionPolicy consts.DialerSelectionPolicy, dialers []*Dialer, aliveChangeCallback func(alive bool), @@ -53,6 +55,7 @@ func NewAliveDialerSet( dialerGroupName: dialerGroupName, l4proto: l4proto, ipversion: ipversion, + tolerance: tolerance, aliveChangeCallback: aliveChangeCallback, dialerToIndex: make(map[*Dialer]int), dialerToLatency: make(map[*Dialer]time.Duration), @@ -146,14 +149,19 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) { bakOldBestDialer := a.minLatency.dialer // Calc minLatency. a.dialerToLatency[dialer] = latency - if alive && latency < a.minLatency.latency { + if alive && latency <= a.minLatency.latency-a.tolerance { a.minLatency.latency = latency a.minLatency.dialer = dialer } else if a.minLatency.dialer == dialer { - a.minLatency.latency = time.Hour - a.minLatency.dialer = nil - a.calcMinLatency() - // Now `a.minLatency.dialer` will be nil if there is no alive dialer. + if latency > a.minLatency.latency { + // Latency increases. + a.minLatency.latency = time.Hour + a.minLatency.dialer = nil + a.calcMinLatency() + // Now `a.minLatency.dialer` will be nil if there is no alive dialer. + } else { + a.minLatency.latency = latency + } } currentAlive := a.minLatency.dialer != nil // If best dialer changed. @@ -169,7 +177,8 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) { string(a.selectionPolicy): a.minLatency.latency, "group": a.dialerGroupName, "network": string(a.l4proto) + string(a.ipversion), - "dialer": a.minLatency.dialer.Name(), + "new dialer": a.minLatency.dialer.Name(), + "old dialer": bakOldBestDialer.Name(), }).Infof("Group %vselects dialer", re) } else { // Alive -> not alive @@ -195,8 +204,11 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) { func (a *AliveDialerSet) calcMinLatency() { for _, d := range a.inorderedAliveDialerSet { - latency := a.dialerToLatency[d] - if latency < a.minLatency.latency { + latency, ok := a.dialerToLatency[d] + if !ok { + continue + } + if latency <= a.minLatency.latency-a.tolerance { a.minLatency.latency = latency a.minLatency.dialer = d } diff --git a/component/outbound/dialer/dialer.go b/component/outbound/dialer/dialer.go index c502594..38e7763 100644 --- a/component/outbound/dialer/dialer.go +++ b/component/outbound/dialer/dialer.go @@ -37,6 +37,7 @@ type GlobalOption struct { TcpCheckOptionRaw TcpCheckOptionRaw // Lazy parse UdpCheckOptionRaw UdpCheckOptionRaw // Lazy parse CheckInterval time.Duration + CheckTolerance time.Duration } type InstanceOption struct { diff --git a/component/outbound/dialer_group.go b/component/outbound/dialer_group.go index bb8e312..179c50e 100644 --- a/component/outbound/dialer_group.go +++ b/component/outbound/dialer_group.go @@ -39,10 +39,10 @@ type DialerGroup struct { func NewDialerGroup(option *dialer.GlobalOption, name string, dialers []*dialer.Dialer, p DialerSelectionPolicy, aliveChangeCallback func(alive bool, l4proto uint8, ipversion uint8)) *DialerGroup { log := option.Log var registeredAliveDialerSet bool - aliveTcp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_4, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 4) }, true) - aliveTcp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_6, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 6) }, true) - aliveUdp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_4, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 4) }, true) - aliveUdp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_6, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 6) }, true) + aliveTcp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_4, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 4) }, true) + aliveTcp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_6, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 6) }, true) + aliveUdp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_4, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 4) }, true) + aliveUdp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_6, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 6) }, true) switch p.Policy { case consts.DialerSelectionPolicy_Random, diff --git a/config/config.go b/config/config.go index 52cee23..1baff44 100644 --- a/config/config.go +++ b/config/config.go @@ -14,14 +14,15 @@ import ( ) type Global struct { - TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"` - LogLevel string `mapstructure:"log_level" default:"info"` - TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com"` - UdpCheckDns string `mapstructure:"udp_check_dns" default:"cloudflare-dns.com:53"` - CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"` - DnsUpstream common.UrlOrEmpty `mapstructure:"dns_upstream" require:""` - LanInterface []string `mapstructure:"lan_interface"` - WanInterface []string `mapstructure:"wan_interface"` + TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"` + LogLevel string `mapstructure:"log_level" default:"info"` + TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com"` + UdpCheckDns string `mapstructure:"udp_check_dns" default:"cloudflare-dns.com:53"` + CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"` + CheckTolerance time.Duration `mapstructure:"check_tolerance" default:"0"` + DnsUpstream common.UrlOrEmpty `mapstructure:"dns_upstream" require:""` + LanInterface []string `mapstructure:"lan_interface"` + WanInterface []string `mapstructure:"wan_interface"` } type Group struct { @@ -47,7 +48,7 @@ type Params struct { Routing Routing `mapstructure:"routing" parser:"RoutingRuleAndParamParser"` } -// New params from sections. This func assumes merging (section "include") and deduplication for sections has been executed. +// New params from sections. This func assumes merging (section "include") and deduplication for section names has been executed. func New(sections []*config_parser.Section) (params *Params, err error) { // Set up name to section for further use. type Section struct { @@ -95,8 +96,11 @@ func New(sections []*config_parser.Section) (params *Params, err error) { section.Parsed = true } - // Report unknown. Not "unused" because we assume deduplication has been executed before this func. + // Report unknown. Not "unused" because we assume section name deduplication has been executed before this func. for name, section := range nameToSection { + if section.Val.Name == "include" { + continue + } if !section.Parsed { return nil, fmt.Errorf("unknown section: %v", name) } diff --git a/config/config_merger.go b/config/config_merger.go new file mode 100644 index 0000000..748de50 --- /dev/null +++ b/config/config_merger.go @@ -0,0 +1,193 @@ +/* + * SPDX-License-Identifier: AGPL-3.0-only + * Copyright (c) since 2023, mzz2017 + */ + +package config + +import ( + "errors" + "fmt" + "github.com/v2rayA/dae/common" + "github.com/v2rayA/dae/pkg/config_parser" + "io" + "os" + "path/filepath" + "strings" +) + +var ( + CircularIncludeError = fmt.Errorf("circular include is not allowed") +) + +type Merger struct { + entry string + entryDir string + entryToSectionMap map[string]map[string][]*config_parser.Item +} + +func NewMerger(entry string) *Merger { + return &Merger{ + entry: entry, + entryDir: filepath.Dir(entry), + entryToSectionMap: map[string]map[string][]*config_parser.Item{}, + } +} + +func (m *Merger) Merge() (sections []*config_parser.Section, entries []string, err error) { + err = m.dfsMerge(m.entry, "") + if err != nil { + return nil, nil, err + } + entries, err = common.MapKeys(m.entryToSectionMap) + if err != nil { + return nil, nil, err + } + return m.convertMapToSections(m.entryToSectionMap[m.entry]), entries, nil +} + +func (m *Merger) readEntry(entry string) (err error) { + // Check circular include. + _, exist := m.entryToSectionMap[entry] + if exist { + return CircularIncludeError + } + + // Check filename + if !strings.HasSuffix(entry, ".dae") { + return fmt.Errorf("invalid config filename %v: must has suffix .dae", entry) + } + // Check file path security. + if err = common.EnsureFileInSubDir(entry, m.entryDir); err != nil { + return fmt.Errorf("failed in checking path of config file %v: %w", entry, err) + } + f, err := os.Open(entry) + if err != nil { + return fmt.Errorf("failed to read config file %v: %w", entry, err) + } + // Check file access. + fi, err := f.Stat() + if err != nil { + return err + } + if fi.IsDir() { + return fmt.Errorf("cannot include a directory: %v", entry) + } + if fi.Mode()&0037 > 0 { + return fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, entry) + } + // Read and parse. + b, err := io.ReadAll(f) + if err != nil { + return err + } + entrySections, err := config_parser.Parse(string(b)) + if err != nil { + return fmt.Errorf("failed to parse config file %v:\n%w", entry, err) + } + m.entryToSectionMap[entry] = m.convertSectionsToMap(entrySections) + return nil +} + +func unsqueezeEntries(patternEntries []string) (unsqueezed []string, err error) { + unsqueezed = make([]string, 0, len(patternEntries)) + for _, pattern := range patternEntries { + files, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + for _, file := range files { + // We only support .dae + if !strings.HasSuffix(file, ".dae") { + continue + } + fi, err := os.Stat(file) + if err != nil { + return nil, err + } + if fi.IsDir() { + continue + } + unsqueezed = append(unsqueezed, file) + } + } + if len(unsqueezed) == 0 { + unsqueezed = nil + } + return unsqueezed, nil +} + +func (m *Merger) dfsMerge(entry string, fatherEntry string) (err error) { + // Read entry and check circular include. + if err = m.readEntry(entry); err != nil { + if errors.Is(err, CircularIncludeError) { + return fmt.Errorf("%w: %v -> %v -> ... -> %v", err, fatherEntry, entry, fatherEntry) + } + return err + } + sectionMap := m.entryToSectionMap[entry] + // Extract childEntries. + includes := sectionMap["include"] + var patterEntries = make([]string, 0, len(includes)) + for _, include := range includes { + switch v := include.Value.(type) { + case *config_parser.Param: + nextEntry := v.String(true) + patterEntries = append(patterEntries, filepath.Join(m.entryDir, nextEntry)) + default: + return fmt.Errorf("unsupported include grammar in %v: %v", entry, include.String()) + } + } + // DFS and merge children recursively. + childEntries, err := unsqueezeEntries(patterEntries) + if err != nil { + return err + } + for _, nextEntry := range childEntries { + if err = m.dfsMerge(nextEntry, entry); err != nil { + return err + } + } + /// Merge into father. Do not need to retrieve sectionMap again because go map is a reference. + if fatherEntry == "" { + // We are already on the top. + return nil + } + fatherSectionMap := m.entryToSectionMap[fatherEntry] + for sec := range sectionMap { + items := m.mergeItems(fatherSectionMap[sec], sectionMap[sec]) + fatherSectionMap[sec] = items + } + return nil +} + +func (m *Merger) convertSectionsToMap(sections []*config_parser.Section) (sectionMap map[string][]*config_parser.Item) { + sectionMap = make(map[string][]*config_parser.Item) + for _, sec := range sections { + items, ok := sectionMap[sec.Name] + if ok { + sectionMap[sec.Name] = m.mergeItems(items, sec.Items) + } else { + sectionMap[sec.Name] = sec.Items + } + } + return sectionMap +} + +func (m *Merger) convertMapToSections(sectionMap map[string][]*config_parser.Item) (sections []*config_parser.Section) { + sections = make([]*config_parser.Section, 0, len(sectionMap)) + for name, items := range sectionMap { + sections = append(sections, &config_parser.Section{ + Name: name, + Items: items, + }) + } + return sections +} + +func (m *Merger) mergeItems(to, from []*config_parser.Item) (items []*config_parser.Item) { + items = make([]*config_parser.Item, len(to)+len(from)) + copy(items, to) + copy(items[len(to):], from) + return items +} diff --git a/control/control_plane.go b/control/control_plane.go index 2936284..82fd0fb 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -196,6 +196,7 @@ func NewControlPlane( TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: global.TcpCheckUrl}, UdpCheckOptionRaw: dialer.UdpCheckOptionRaw{Raw: global.UdpCheckDns}, CheckInterval: global.CheckInterval, + CheckTolerance: global.CheckTolerance, } outbounds := []*outbound.DialerGroup{ outbound.NewDialerGroup(option, consts.OutboundDirect.String(), diff --git a/example.dae b/example.dae index 3c45f7c..b2ad2a0 100644 --- a/example.dae +++ b/example.dae @@ -10,6 +10,8 @@ global { tcp_check_url: 'http://cp.cloudflare.com' udp_check_dns: 'dns.google:53' check_interval: 30s + # Group will switch node only when new_latency <= old_latency - tolerance + check_tolerance: 50ms # Value can be scheme://host:port or empty string ''. # The scheme can be tcp/udp/tcp+udp. Empty string '' indicates as-is. @@ -74,7 +76,7 @@ routing { ip(geoip:private, 224.0.0.0/3, 'ff00::/8') -> direct # Put it first unless you know what you're doing. # Write your rules below. - # dae arms DNS rush-answer filter so we can use 8.8.8.8 regardless of DNS pollution. + # dae arms DNS rush-answer filter so we can use dns.google regardless of DNS pollution. domain(full:dns.google) && port(53) -> direct pname(firefox) && domain(ip.sb) -> direct