Merge pull request #12 from v2rayA/feat_include

This commit is contained in:
mzz 2023-02-10 00:00:05 +08:00 committed by GitHub
commit d9d4b94e93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 347 additions and 44 deletions

View File

@ -121,7 +121,5 @@ See [example.dae](https://github.com/v2rayA/dae/blob/main/example.dae).
But the problem is, after the Linux network stack, before entering the network card, we modify the source IP of this packet, causing the Linux network stack to only make a simple checksum, and the NIC also assumes that this packet is not sent from local, so no further checksum completing.
1. MACv2 extension extraction.
1. Log to userspace.
1. Support include section.
1. Subscription section supports "file://"
1. Subscription section supports key.
1. Subscription section supports key. And support to filter by subscription key.
1. ...

View File

@ -1,6 +1,8 @@
package internal
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"github.com/sirupsen/logrus"
@ -9,6 +11,8 @@ import (
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
)
@ -80,26 +84,82 @@ func resolveSubscriptionAsSIP008(log *logrus.Logger, b []byte) (nodes []string,
return nodes, nil
}
func ResolveSubscription(log *logrus.Logger, subscription string) (nodes []string, err error) {
func resolveFile(u *url.URL, configDir string) (b []byte, err error) {
if u.Host == "" {
return nil, fmt.Errorf("not support absolute path")
}
/// Relative location.
// Make sure path is secure.
path := filepath.Join(configDir, u.Host, u.Path)
if err = common.EnsureFileInSubDir(path, configDir); err != nil {
return nil, err
}
/// Read and resolve
f, err := os.Open(path)
if err != nil {
return nil, err
}
// Check file access.
fi, err := f.Stat()
if err != nil {
return nil, err
}
if fi.IsDir() {
return nil, fmt.Errorf("subscription file cannot be a directory: %v", path)
}
if fi.Mode()&0037 > 0 {
return nil, fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, path)
}
// Resolve the first line instruction.
fReader := bufio.NewReader(f)
b, err = fReader.Peek(1)
if err != nil {
return nil, err
}
if string(b[0]) == "@" {
// Instruction line. But not support yet.
_, _, err = fReader.ReadLine()
if err != nil {
return nil, err
}
}
b, err = io.ReadAll(fReader)
if err != nil {
return nil, err
}
return bytes.TrimSpace(b), err
}
func ResolveSubscription(log *logrus.Logger, configDir string, subscription string) (nodes []string, err error) {
u, err := url.Parse(subscription)
if err != nil {
return nil, fmt.Errorf("failed to parse subscription \"%v\": %w", subscription, err)
}
log.Debugf("ResolveSubscription: %v", subscription)
var (
b []byte
resp *http.Response
)
switch u.Scheme {
case "file":
// TODO
b, err = resolveFile(u, configDir)
if err != nil {
return nil, err
}
goto resolve
default:
}
log.Debugf("ResolveSubscription: %v", subscription)
resp, err := http.Get(subscription)
resp, err = http.Get(subscription)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
b, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
resolve:
if nodes, err = resolveSubscriptionAsSIP008(log, b); err == nil {
return nodes, nil
} else {

View File

@ -7,10 +7,11 @@ import (
"github.com/v2rayA/dae/cmd/internal"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/control"
"github.com/v2rayA/dae/pkg/config_parser"
"github.com/v2rayA/dae/pkg/logger"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
)
@ -30,13 +31,15 @@ var (
internal.AutoSu()
// Read config from --config cfgFile.
param, err := readConfig(cfgFile)
param, includes, err := readConfig(cfgFile)
if err != nil {
logrus.Fatalln("readConfig:", err)
}
log := logger.NewLogger(param.Global.LogLevel, disableTimestamp)
logrus.SetLevel(log.Level)
log.Infof("Include config files: [%v]", strings.Join(includes, ", "))
if err := Run(log, param); err != nil {
logrus.Fatalln(err)
}
@ -55,7 +58,7 @@ func Run(log *logrus.Logger, param *config.Params) (err error) {
nodeList := make([]string, len(param.Node))
copy(nodeList, param.Node)
for _, sub := range param.Subscription {
nodes, err := internal.ResolveSubscription(log, sub)
nodes, err := internal.ResolveSubscription(log, filepath.Dir(cfgFile), sub)
if err != nil {
log.Warnf(`failed to resolve subscription "%v": %v`, sub, err)
}
@ -97,17 +100,14 @@ func Run(log *logrus.Logger, param *config.Params) (err error) {
return nil
}
func readConfig(cfgFile string) (params *config.Params, err error) {
b, err := os.ReadFile(cfgFile)
func readConfig(cfgFile string) (params *config.Params, entries []string, err error) {
merger := config.NewMerger(cfgFile)
sections, entries, err := merger.Merge()
if err != nil {
return nil, err
}
sections, err := config_parser.Parse(string(b))
if err != nil {
return nil, fmt.Errorf("\n%w", err)
return nil, nil, err
}
if params, err = config.New(sections); err != nil {
return nil, err
return nil, nil, err
}
return params, nil
return params, entries, nil
}

View File

@ -21,7 +21,7 @@ var (
os.Exit(1)
}
// Read config from --config cfgFile.
_, err := readConfig(cfgFile)
_, _, err := readConfig(cfgFile)
if err != nil {
fmt.Println(err)
os.Exit(1)

View File

@ -11,6 +11,7 @@ import (
"encoding/hex"
"fmt"
"net/url"
"path/filepath"
"reflect"
"strconv"
"strings"
@ -307,3 +308,34 @@ func FuzzyDecode(to interface{}, val string) bool {
}
return true
}
func EnsureFileInSubDir(filePath string, dir string) (err error) {
fileDir := filepath.Dir(filePath)
if len(dir) == 0 {
return fmt.Errorf("bad dir: %v", dir)
}
rel, err := filepath.Rel(dir, fileDir)
if err != nil {
return err
}
if strings.HasPrefix(rel, "..") {
return fmt.Errorf("file is out of scope: %v", rel)
}
return nil
}
func MapKeys(m interface{}) (keys []string, err error) {
v := reflect.ValueOf(m)
if v.Kind() != reflect.Map {
return nil, fmt.Errorf("MapKeys requires map[string]*")
}
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("MapKeys requires map[string]*")
}
_keys := v.MapKeys()
keys = make([]string, 0, len(_keys))
for _, k := range _keys {
keys = append(keys, k.String())
}
return keys, nil
}

View File

@ -26,6 +26,7 @@ type AliveDialerSet struct {
dialerGroupName string
l4proto consts.L4ProtoStr
ipversion consts.IpVersionStr
tolerance time.Duration
aliveChangeCallback func(alive bool)
@ -43,6 +44,7 @@ func NewAliveDialerSet(
dialerGroupName string,
l4proto consts.L4ProtoStr,
ipversion consts.IpVersionStr,
tolerance time.Duration,
selectionPolicy consts.DialerSelectionPolicy,
dialers []*Dialer,
aliveChangeCallback func(alive bool),
@ -53,6 +55,7 @@ func NewAliveDialerSet(
dialerGroupName: dialerGroupName,
l4proto: l4proto,
ipversion: ipversion,
tolerance: tolerance,
aliveChangeCallback: aliveChangeCallback,
dialerToIndex: make(map[*Dialer]int),
dialerToLatency: make(map[*Dialer]time.Duration),
@ -146,14 +149,19 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
bakOldBestDialer := a.minLatency.dialer
// Calc minLatency.
a.dialerToLatency[dialer] = latency
if alive && latency < a.minLatency.latency {
if alive && latency <= a.minLatency.latency-a.tolerance {
a.minLatency.latency = latency
a.minLatency.dialer = dialer
} else if a.minLatency.dialer == dialer {
a.minLatency.latency = time.Hour
a.minLatency.dialer = nil
a.calcMinLatency()
// Now `a.minLatency.dialer` will be nil if there is no alive dialer.
if latency > a.minLatency.latency {
// Latency increases.
a.minLatency.latency = time.Hour
a.minLatency.dialer = nil
a.calcMinLatency()
// Now `a.minLatency.dialer` will be nil if there is no alive dialer.
} else {
a.minLatency.latency = latency
}
}
currentAlive := a.minLatency.dialer != nil
// If best dialer changed.
@ -169,7 +177,8 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
string(a.selectionPolicy): a.minLatency.latency,
"group": a.dialerGroupName,
"network": string(a.l4proto) + string(a.ipversion),
"dialer": a.minLatency.dialer.Name(),
"new dialer": a.minLatency.dialer.Name(),
"old dialer": bakOldBestDialer.Name(),
}).Infof("Group %vselects dialer", re)
} else {
// Alive -> not alive
@ -195,8 +204,11 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
func (a *AliveDialerSet) calcMinLatency() {
for _, d := range a.inorderedAliveDialerSet {
latency := a.dialerToLatency[d]
if latency < a.minLatency.latency {
latency, ok := a.dialerToLatency[d]
if !ok {
continue
}
if latency <= a.minLatency.latency-a.tolerance {
a.minLatency.latency = latency
a.minLatency.dialer = d
}

View File

@ -37,6 +37,7 @@ type GlobalOption struct {
TcpCheckOptionRaw TcpCheckOptionRaw // Lazy parse
UdpCheckOptionRaw UdpCheckOptionRaw // Lazy parse
CheckInterval time.Duration
CheckTolerance time.Duration
}
type InstanceOption struct {

View File

@ -39,10 +39,10 @@ type DialerGroup struct {
func NewDialerGroup(option *dialer.GlobalOption, name string, dialers []*dialer.Dialer, p DialerSelectionPolicy, aliveChangeCallback func(alive bool, l4proto uint8, ipversion uint8)) *DialerGroup {
log := option.Log
var registeredAliveDialerSet bool
aliveTcp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_4, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 4) }, true)
aliveTcp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_6, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 6) }, true)
aliveUdp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_4, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 4) }, true)
aliveUdp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_6, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 6) }, true)
aliveTcp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_4, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 4) }, true)
aliveTcp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_TCP, consts.IpVersionStr_6, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_TCP, 6) }, true)
aliveUdp4DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_4, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 4) }, true)
aliveUdp6DialerSet := dialer.NewAliveDialerSet(log, name, consts.L4ProtoStr_UDP, consts.IpVersionStr_6, option.CheckTolerance, p.Policy, dialers, func(alive bool) { aliveChangeCallback(alive, unix.IPPROTO_UDP, 6) }, true)
switch p.Policy {
case consts.DialerSelectionPolicy_Random,

View File

@ -14,14 +14,15 @@ import (
)
type Global struct {
TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"`
LogLevel string `mapstructure:"log_level" default:"info"`
TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com"`
UdpCheckDns string `mapstructure:"udp_check_dns" default:"cloudflare-dns.com:53"`
CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"`
DnsUpstream common.UrlOrEmpty `mapstructure:"dns_upstream" require:""`
LanInterface []string `mapstructure:"lan_interface"`
WanInterface []string `mapstructure:"wan_interface"`
TproxyPort uint16 `mapstructure:"tproxy_port" default:"12345"`
LogLevel string `mapstructure:"log_level" default:"info"`
TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://cp.cloudflare.com"`
UdpCheckDns string `mapstructure:"udp_check_dns" default:"cloudflare-dns.com:53"`
CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"`
CheckTolerance time.Duration `mapstructure:"check_tolerance" default:"0"`
DnsUpstream common.UrlOrEmpty `mapstructure:"dns_upstream" require:""`
LanInterface []string `mapstructure:"lan_interface"`
WanInterface []string `mapstructure:"wan_interface"`
}
type Group struct {
@ -47,7 +48,7 @@ type Params struct {
Routing Routing `mapstructure:"routing" parser:"RoutingRuleAndParamParser"`
}
// New params from sections. This func assumes merging (section "include") and deduplication for sections has been executed.
// New params from sections. This func assumes merging (section "include") and deduplication for section names has been executed.
func New(sections []*config_parser.Section) (params *Params, err error) {
// Set up name to section for further use.
type Section struct {
@ -95,8 +96,11 @@ func New(sections []*config_parser.Section) (params *Params, err error) {
section.Parsed = true
}
// Report unknown. Not "unused" because we assume deduplication has been executed before this func.
// Report unknown. Not "unused" because we assume section name deduplication has been executed before this func.
for name, section := range nameToSection {
if section.Val.Name == "include" {
continue
}
if !section.Parsed {
return nil, fmt.Errorf("unknown section: %v", name)
}

193
config/config_merger.go Normal file
View File

@ -0,0 +1,193 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package config
import (
"errors"
"fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/pkg/config_parser"
"io"
"os"
"path/filepath"
"strings"
)
var (
CircularIncludeError = fmt.Errorf("circular include is not allowed")
)
type Merger struct {
entry string
entryDir string
entryToSectionMap map[string]map[string][]*config_parser.Item
}
func NewMerger(entry string) *Merger {
return &Merger{
entry: entry,
entryDir: filepath.Dir(entry),
entryToSectionMap: map[string]map[string][]*config_parser.Item{},
}
}
func (m *Merger) Merge() (sections []*config_parser.Section, entries []string, err error) {
err = m.dfsMerge(m.entry, "")
if err != nil {
return nil, nil, err
}
entries, err = common.MapKeys(m.entryToSectionMap)
if err != nil {
return nil, nil, err
}
return m.convertMapToSections(m.entryToSectionMap[m.entry]), entries, nil
}
func (m *Merger) readEntry(entry string) (err error) {
// Check circular include.
_, exist := m.entryToSectionMap[entry]
if exist {
return CircularIncludeError
}
// Check filename
if !strings.HasSuffix(entry, ".dae") {
return fmt.Errorf("invalid config filename %v: must has suffix .dae", entry)
}
// Check file path security.
if err = common.EnsureFileInSubDir(entry, m.entryDir); err != nil {
return fmt.Errorf("failed in checking path of config file %v: %w", entry, err)
}
f, err := os.Open(entry)
if err != nil {
return fmt.Errorf("failed to read config file %v: %w", entry, err)
}
// Check file access.
fi, err := f.Stat()
if err != nil {
return err
}
if fi.IsDir() {
return fmt.Errorf("cannot include a directory: %v", entry)
}
if fi.Mode()&0037 > 0 {
return fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, entry)
}
// Read and parse.
b, err := io.ReadAll(f)
if err != nil {
return err
}
entrySections, err := config_parser.Parse(string(b))
if err != nil {
return fmt.Errorf("failed to parse config file %v:\n%w", entry, err)
}
m.entryToSectionMap[entry] = m.convertSectionsToMap(entrySections)
return nil
}
func unsqueezeEntries(patternEntries []string) (unsqueezed []string, err error) {
unsqueezed = make([]string, 0, len(patternEntries))
for _, pattern := range patternEntries {
files, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, file := range files {
// We only support .dae
if !strings.HasSuffix(file, ".dae") {
continue
}
fi, err := os.Stat(file)
if err != nil {
return nil, err
}
if fi.IsDir() {
continue
}
unsqueezed = append(unsqueezed, file)
}
}
if len(unsqueezed) == 0 {
unsqueezed = nil
}
return unsqueezed, nil
}
func (m *Merger) dfsMerge(entry string, fatherEntry string) (err error) {
// Read entry and check circular include.
if err = m.readEntry(entry); err != nil {
if errors.Is(err, CircularIncludeError) {
return fmt.Errorf("%w: %v -> %v -> ... -> %v", err, fatherEntry, entry, fatherEntry)
}
return err
}
sectionMap := m.entryToSectionMap[entry]
// Extract childEntries.
includes := sectionMap["include"]
var patterEntries = make([]string, 0, len(includes))
for _, include := range includes {
switch v := include.Value.(type) {
case *config_parser.Param:
nextEntry := v.String(true)
patterEntries = append(patterEntries, filepath.Join(m.entryDir, nextEntry))
default:
return fmt.Errorf("unsupported include grammar in %v: %v", entry, include.String())
}
}
// DFS and merge children recursively.
childEntries, err := unsqueezeEntries(patterEntries)
if err != nil {
return err
}
for _, nextEntry := range childEntries {
if err = m.dfsMerge(nextEntry, entry); err != nil {
return err
}
}
/// Merge into father. Do not need to retrieve sectionMap again because go map is a reference.
if fatherEntry == "" {
// We are already on the top.
return nil
}
fatherSectionMap := m.entryToSectionMap[fatherEntry]
for sec := range sectionMap {
items := m.mergeItems(fatherSectionMap[sec], sectionMap[sec])
fatherSectionMap[sec] = items
}
return nil
}
func (m *Merger) convertSectionsToMap(sections []*config_parser.Section) (sectionMap map[string][]*config_parser.Item) {
sectionMap = make(map[string][]*config_parser.Item)
for _, sec := range sections {
items, ok := sectionMap[sec.Name]
if ok {
sectionMap[sec.Name] = m.mergeItems(items, sec.Items)
} else {
sectionMap[sec.Name] = sec.Items
}
}
return sectionMap
}
func (m *Merger) convertMapToSections(sectionMap map[string][]*config_parser.Item) (sections []*config_parser.Section) {
sections = make([]*config_parser.Section, 0, len(sectionMap))
for name, items := range sectionMap {
sections = append(sections, &config_parser.Section{
Name: name,
Items: items,
})
}
return sections
}
func (m *Merger) mergeItems(to, from []*config_parser.Item) (items []*config_parser.Item) {
items = make([]*config_parser.Item, len(to)+len(from))
copy(items, to)
copy(items[len(to):], from)
return items
}

View File

@ -196,6 +196,7 @@ func NewControlPlane(
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: global.TcpCheckUrl},
UdpCheckOptionRaw: dialer.UdpCheckOptionRaw{Raw: global.UdpCheckDns},
CheckInterval: global.CheckInterval,
CheckTolerance: global.CheckTolerance,
}
outbounds := []*outbound.DialerGroup{
outbound.NewDialerGroup(option, consts.OutboundDirect.String(),

View File

@ -10,6 +10,8 @@ global {
tcp_check_url: 'http://cp.cloudflare.com'
udp_check_dns: 'dns.google:53'
check_interval: 30s
# Group will switch node only when new_latency <= old_latency - tolerance
check_tolerance: 50ms
# Value can be scheme://host:port or empty string ''.
# The scheme can be tcp/udp/tcp+udp. Empty string '' indicates as-is.
@ -74,7 +76,7 @@ routing {
ip(geoip:private, 224.0.0.0/3, 'ff00::/8') -> direct # Put it first unless you know what you're doing.
# Write your rules below.
# dae arms DNS rush-answer filter so we can use 8.8.8.8 regardless of DNS pollution.
# dae arms DNS rush-answer filter so we can use dns.google regardless of DNS pollution.
domain(full:dns.google) && port(53) -> direct
pname(firefox) && domain(ip.sb) -> direct