From af743cd1c6ae346df8f1cadff24ed3d4d1db5434 Mon Sep 17 00:00:00 2001 From: mzz2017 <2017@duck.com> Date: Thu, 9 Feb 2023 23:17:49 +0800 Subject: [PATCH] feat: support include --- cmd/internal/subscription.go | 15 ++- cmd/run.go | 21 ++-- cmd/validate.go | 2 +- common/utils.go | 18 +++- config/config.go | 7 +- config/config_merger.go | 193 +++++++++++++++++++++++++++++++++++ 6 files changed, 239 insertions(+), 17 deletions(-) create mode 100644 config/config_merger.go diff --git a/cmd/internal/subscription.go b/cmd/internal/subscription.go index bfb4741..35c08e0 100644 --- a/cmd/internal/subscription.go +++ b/cmd/internal/subscription.go @@ -91,7 +91,7 @@ func resolveFile(u *url.URL, configDir string) (b []byte, err error) { /// Relative location. // Make sure path is secure. path := filepath.Join(configDir, u.Host, u.Path) - if err = common.IsFileInSubDir(path, configDir); err != nil { + if err = common.EnsureFileInSubDir(path, configDir); err != nil { return nil, err } /// Read and resolve @@ -99,6 +99,17 @@ func resolveFile(u *url.URL, configDir string) (b []byte, err error) { if err != nil { return nil, err } + // Check file access. + fi, err := f.Stat() + if err != nil { + return nil, err + } + if fi.IsDir() { + return nil, fmt.Errorf("subscription file cannot be a directory: %v", path) + } + if fi.Mode()&0037 > 0 { + return nil, fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, path) + } // Resolve the first line instruction. fReader := bufio.NewReader(f) b, err = fReader.Peek(1) @@ -134,7 +145,7 @@ func ResolveSubscription(log *logrus.Logger, configDir string, subscription stri case "file": b, err = resolveFile(u, configDir) if err != nil { - return nil, fmt.Errorf("failed to resolve file: %w", err) + return nil, err } goto resolve default: diff --git a/cmd/run.go b/cmd/run.go index 80e7ed5..e2a5976 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -7,11 +7,11 @@ import ( "github.com/v2rayA/dae/cmd/internal" "github.com/v2rayA/dae/config" "github.com/v2rayA/dae/control" - "github.com/v2rayA/dae/pkg/config_parser" "github.com/v2rayA/dae/pkg/logger" "os" "os/signal" "path/filepath" + "strings" "syscall" ) @@ -31,13 +31,15 @@ var ( internal.AutoSu() // Read config from --config cfgFile. - param, err := readConfig(cfgFile) + param, includes, err := readConfig(cfgFile) if err != nil { logrus.Fatalln("readConfig:", err) } log := logger.NewLogger(param.Global.LogLevel, disableTimestamp) logrus.SetLevel(log.Level) + + log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) if err := Run(log, param); err != nil { logrus.Fatalln(err) } @@ -98,17 +100,14 @@ func Run(log *logrus.Logger, param *config.Params) (err error) { return nil } -func readConfig(cfgFile string) (params *config.Params, err error) { - b, err := os.ReadFile(cfgFile) +func readConfig(cfgFile string) (params *config.Params, entries []string, err error) { + merger := config.NewMerger(cfgFile) + sections, entries, err := merger.Merge() if err != nil { - return nil, err - } - sections, err := config_parser.Parse(string(b)) - if err != nil { - return nil, fmt.Errorf("\n%w", err) + return nil, nil, err } if params, err = config.New(sections); err != nil { - return nil, err + return nil, nil, err } - return params, nil + return params, entries, nil } diff --git a/cmd/validate.go b/cmd/validate.go index 93e929b..ad46e1b 100644 --- a/cmd/validate.go +++ b/cmd/validate.go @@ -21,7 +21,7 @@ var ( os.Exit(1) } // Read config from --config cfgFile. - _, err := readConfig(cfgFile) + _, _, err := readConfig(cfgFile) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/common/utils.go b/common/utils.go index fc28c01..74c1155 100644 --- a/common/utils.go +++ b/common/utils.go @@ -309,7 +309,7 @@ func FuzzyDecode(to interface{}, val string) bool { return true } -func IsFileInSubDir(filePath string, dir string) (err error) { +func EnsureFileInSubDir(filePath string, dir string) (err error) { fileDir := filepath.Dir(filePath) if len(dir) == 0 { return fmt.Errorf("bad dir: %v", dir) @@ -323,3 +323,19 @@ func IsFileInSubDir(filePath string, dir string) (err error) { } return nil } + +func MapKeys(m interface{}) (keys []string, err error) { + v := reflect.ValueOf(m) + if v.Kind() != reflect.Map { + return nil, fmt.Errorf("MapKeys requires map[string]*") + } + if v.Type().Key().Kind() != reflect.String { + return nil, fmt.Errorf("MapKeys requires map[string]*") + } + _keys := v.MapKeys() + keys = make([]string, 0, len(_keys)) + for _, k := range _keys { + keys = append(keys, k.String()) + } + return keys, nil +} diff --git a/config/config.go b/config/config.go index 59ecfa0..1baff44 100644 --- a/config/config.go +++ b/config/config.go @@ -48,7 +48,7 @@ type Params struct { Routing Routing `mapstructure:"routing" parser:"RoutingRuleAndParamParser"` } -// New params from sections. This func assumes merging (section "include") and deduplication for sections has been executed. +// New params from sections. This func assumes merging (section "include") and deduplication for section names has been executed. func New(sections []*config_parser.Section) (params *Params, err error) { // Set up name to section for further use. type Section struct { @@ -96,8 +96,11 @@ func New(sections []*config_parser.Section) (params *Params, err error) { section.Parsed = true } - // Report unknown. Not "unused" because we assume deduplication has been executed before this func. + // Report unknown. Not "unused" because we assume section name deduplication has been executed before this func. for name, section := range nameToSection { + if section.Val.Name == "include" { + continue + } if !section.Parsed { return nil, fmt.Errorf("unknown section: %v", name) } diff --git a/config/config_merger.go b/config/config_merger.go new file mode 100644 index 0000000..748de50 --- /dev/null +++ b/config/config_merger.go @@ -0,0 +1,193 @@ +/* + * SPDX-License-Identifier: AGPL-3.0-only + * Copyright (c) since 2023, mzz2017 + */ + +package config + +import ( + "errors" + "fmt" + "github.com/v2rayA/dae/common" + "github.com/v2rayA/dae/pkg/config_parser" + "io" + "os" + "path/filepath" + "strings" +) + +var ( + CircularIncludeError = fmt.Errorf("circular include is not allowed") +) + +type Merger struct { + entry string + entryDir string + entryToSectionMap map[string]map[string][]*config_parser.Item +} + +func NewMerger(entry string) *Merger { + return &Merger{ + entry: entry, + entryDir: filepath.Dir(entry), + entryToSectionMap: map[string]map[string][]*config_parser.Item{}, + } +} + +func (m *Merger) Merge() (sections []*config_parser.Section, entries []string, err error) { + err = m.dfsMerge(m.entry, "") + if err != nil { + return nil, nil, err + } + entries, err = common.MapKeys(m.entryToSectionMap) + if err != nil { + return nil, nil, err + } + return m.convertMapToSections(m.entryToSectionMap[m.entry]), entries, nil +} + +func (m *Merger) readEntry(entry string) (err error) { + // Check circular include. + _, exist := m.entryToSectionMap[entry] + if exist { + return CircularIncludeError + } + + // Check filename + if !strings.HasSuffix(entry, ".dae") { + return fmt.Errorf("invalid config filename %v: must has suffix .dae", entry) + } + // Check file path security. + if err = common.EnsureFileInSubDir(entry, m.entryDir); err != nil { + return fmt.Errorf("failed in checking path of config file %v: %w", entry, err) + } + f, err := os.Open(entry) + if err != nil { + return fmt.Errorf("failed to read config file %v: %w", entry, err) + } + // Check file access. + fi, err := f.Stat() + if err != nil { + return err + } + if fi.IsDir() { + return fmt.Errorf("cannot include a directory: %v", entry) + } + if fi.Mode()&0037 > 0 { + return fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, entry) + } + // Read and parse. + b, err := io.ReadAll(f) + if err != nil { + return err + } + entrySections, err := config_parser.Parse(string(b)) + if err != nil { + return fmt.Errorf("failed to parse config file %v:\n%w", entry, err) + } + m.entryToSectionMap[entry] = m.convertSectionsToMap(entrySections) + return nil +} + +func unsqueezeEntries(patternEntries []string) (unsqueezed []string, err error) { + unsqueezed = make([]string, 0, len(patternEntries)) + for _, pattern := range patternEntries { + files, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + for _, file := range files { + // We only support .dae + if !strings.HasSuffix(file, ".dae") { + continue + } + fi, err := os.Stat(file) + if err != nil { + return nil, err + } + if fi.IsDir() { + continue + } + unsqueezed = append(unsqueezed, file) + } + } + if len(unsqueezed) == 0 { + unsqueezed = nil + } + return unsqueezed, nil +} + +func (m *Merger) dfsMerge(entry string, fatherEntry string) (err error) { + // Read entry and check circular include. + if err = m.readEntry(entry); err != nil { + if errors.Is(err, CircularIncludeError) { + return fmt.Errorf("%w: %v -> %v -> ... -> %v", err, fatherEntry, entry, fatherEntry) + } + return err + } + sectionMap := m.entryToSectionMap[entry] + // Extract childEntries. + includes := sectionMap["include"] + var patterEntries = make([]string, 0, len(includes)) + for _, include := range includes { + switch v := include.Value.(type) { + case *config_parser.Param: + nextEntry := v.String(true) + patterEntries = append(patterEntries, filepath.Join(m.entryDir, nextEntry)) + default: + return fmt.Errorf("unsupported include grammar in %v: %v", entry, include.String()) + } + } + // DFS and merge children recursively. + childEntries, err := unsqueezeEntries(patterEntries) + if err != nil { + return err + } + for _, nextEntry := range childEntries { + if err = m.dfsMerge(nextEntry, entry); err != nil { + return err + } + } + /// Merge into father. Do not need to retrieve sectionMap again because go map is a reference. + if fatherEntry == "" { + // We are already on the top. + return nil + } + fatherSectionMap := m.entryToSectionMap[fatherEntry] + for sec := range sectionMap { + items := m.mergeItems(fatherSectionMap[sec], sectionMap[sec]) + fatherSectionMap[sec] = items + } + return nil +} + +func (m *Merger) convertSectionsToMap(sections []*config_parser.Section) (sectionMap map[string][]*config_parser.Item) { + sectionMap = make(map[string][]*config_parser.Item) + for _, sec := range sections { + items, ok := sectionMap[sec.Name] + if ok { + sectionMap[sec.Name] = m.mergeItems(items, sec.Items) + } else { + sectionMap[sec.Name] = sec.Items + } + } + return sectionMap +} + +func (m *Merger) convertMapToSections(sectionMap map[string][]*config_parser.Item) (sections []*config_parser.Section) { + sections = make([]*config_parser.Section, 0, len(sectionMap)) + for name, items := range sectionMap { + sections = append(sections, &config_parser.Section{ + Name: name, + Items: items, + }) + } + return sections +} + +func (m *Merger) mergeItems(to, from []*config_parser.Item) (items []*config_parser.Item) { + items = make([]*config_parser.Item, len(to)+len(from)) + copy(items, to) + copy(items[len(to):], from) + return items +}