feat: support include

This commit is contained in:
mzz2017
2023-02-09 23:17:49 +08:00
parent 158dfb2a27
commit af743cd1c6
6 changed files with 239 additions and 17 deletions

View File

@ -91,7 +91,7 @@ func resolveFile(u *url.URL, configDir string) (b []byte, err error) {
/// Relative location.
// Make sure path is secure.
path := filepath.Join(configDir, u.Host, u.Path)
if err = common.IsFileInSubDir(path, configDir); err != nil {
if err = common.EnsureFileInSubDir(path, configDir); err != nil {
return nil, err
}
/// Read and resolve
@ -99,6 +99,17 @@ func resolveFile(u *url.URL, configDir string) (b []byte, err error) {
if err != nil {
return nil, err
}
// Check file access.
fi, err := f.Stat()
if err != nil {
return nil, err
}
if fi.IsDir() {
return nil, fmt.Errorf("subscription file cannot be a directory: %v", path)
}
if fi.Mode()&0037 > 0 {
return nil, fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, path)
}
// Resolve the first line instruction.
fReader := bufio.NewReader(f)
b, err = fReader.Peek(1)
@ -134,7 +145,7 @@ func ResolveSubscription(log *logrus.Logger, configDir string, subscription stri
case "file":
b, err = resolveFile(u, configDir)
if err != nil {
return nil, fmt.Errorf("failed to resolve file: %w", err)
return nil, err
}
goto resolve
default:

View File

@ -7,11 +7,11 @@ import (
"github.com/v2rayA/dae/cmd/internal"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/control"
"github.com/v2rayA/dae/pkg/config_parser"
"github.com/v2rayA/dae/pkg/logger"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
)
@ -31,13 +31,15 @@ var (
internal.AutoSu()
// Read config from --config cfgFile.
param, err := readConfig(cfgFile)
param, includes, err := readConfig(cfgFile)
if err != nil {
logrus.Fatalln("readConfig:", err)
}
log := logger.NewLogger(param.Global.LogLevel, disableTimestamp)
logrus.SetLevel(log.Level)
log.Infof("Include config files: [%v]", strings.Join(includes, ", "))
if err := Run(log, param); err != nil {
logrus.Fatalln(err)
}
@ -98,17 +100,14 @@ func Run(log *logrus.Logger, param *config.Params) (err error) {
return nil
}
func readConfig(cfgFile string) (params *config.Params, err error) {
b, err := os.ReadFile(cfgFile)
func readConfig(cfgFile string) (params *config.Params, entries []string, err error) {
merger := config.NewMerger(cfgFile)
sections, entries, err := merger.Merge()
if err != nil {
return nil, err
}
sections, err := config_parser.Parse(string(b))
if err != nil {
return nil, fmt.Errorf("\n%w", err)
return nil, nil, err
}
if params, err = config.New(sections); err != nil {
return nil, err
return nil, nil, err
}
return params, nil
return params, entries, nil
}

View File

@ -21,7 +21,7 @@ var (
os.Exit(1)
}
// Read config from --config cfgFile.
_, err := readConfig(cfgFile)
_, _, err := readConfig(cfgFile)
if err != nil {
fmt.Println(err)
os.Exit(1)

View File

@ -309,7 +309,7 @@ func FuzzyDecode(to interface{}, val string) bool {
return true
}
func IsFileInSubDir(filePath string, dir string) (err error) {
func EnsureFileInSubDir(filePath string, dir string) (err error) {
fileDir := filepath.Dir(filePath)
if len(dir) == 0 {
return fmt.Errorf("bad dir: %v", dir)
@ -323,3 +323,19 @@ func IsFileInSubDir(filePath string, dir string) (err error) {
}
return nil
}
func MapKeys(m interface{}) (keys []string, err error) {
v := reflect.ValueOf(m)
if v.Kind() != reflect.Map {
return nil, fmt.Errorf("MapKeys requires map[string]*")
}
if v.Type().Key().Kind() != reflect.String {
return nil, fmt.Errorf("MapKeys requires map[string]*")
}
_keys := v.MapKeys()
keys = make([]string, 0, len(_keys))
for _, k := range _keys {
keys = append(keys, k.String())
}
return keys, nil
}

View File

@ -48,7 +48,7 @@ type Params struct {
Routing Routing `mapstructure:"routing" parser:"RoutingRuleAndParamParser"`
}
// New params from sections. This func assumes merging (section "include") and deduplication for sections has been executed.
// New params from sections. This func assumes merging (section "include") and deduplication for section names has been executed.
func New(sections []*config_parser.Section) (params *Params, err error) {
// Set up name to section for further use.
type Section struct {
@ -96,8 +96,11 @@ func New(sections []*config_parser.Section) (params *Params, err error) {
section.Parsed = true
}
// Report unknown. Not "unused" because we assume deduplication has been executed before this func.
// Report unknown. Not "unused" because we assume section name deduplication has been executed before this func.
for name, section := range nameToSection {
if section.Val.Name == "include" {
continue
}
if !section.Parsed {
return nil, fmt.Errorf("unknown section: %v", name)
}

193
config/config_merger.go Normal file
View File

@ -0,0 +1,193 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2023, mzz2017 <mzz@tuta.io>
*/
package config
import (
"errors"
"fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/pkg/config_parser"
"io"
"os"
"path/filepath"
"strings"
)
var (
CircularIncludeError = fmt.Errorf("circular include is not allowed")
)
type Merger struct {
entry string
entryDir string
entryToSectionMap map[string]map[string][]*config_parser.Item
}
func NewMerger(entry string) *Merger {
return &Merger{
entry: entry,
entryDir: filepath.Dir(entry),
entryToSectionMap: map[string]map[string][]*config_parser.Item{},
}
}
func (m *Merger) Merge() (sections []*config_parser.Section, entries []string, err error) {
err = m.dfsMerge(m.entry, "")
if err != nil {
return nil, nil, err
}
entries, err = common.MapKeys(m.entryToSectionMap)
if err != nil {
return nil, nil, err
}
return m.convertMapToSections(m.entryToSectionMap[m.entry]), entries, nil
}
func (m *Merger) readEntry(entry string) (err error) {
// Check circular include.
_, exist := m.entryToSectionMap[entry]
if exist {
return CircularIncludeError
}
// Check filename
if !strings.HasSuffix(entry, ".dae") {
return fmt.Errorf("invalid config filename %v: must has suffix .dae", entry)
}
// Check file path security.
if err = common.EnsureFileInSubDir(entry, m.entryDir); err != nil {
return fmt.Errorf("failed in checking path of config file %v: %w", entry, err)
}
f, err := os.Open(entry)
if err != nil {
return fmt.Errorf("failed to read config file %v: %w", entry, err)
}
// Check file access.
fi, err := f.Stat()
if err != nil {
return err
}
if fi.IsDir() {
return fmt.Errorf("cannot include a directory: %v", entry)
}
if fi.Mode()&0037 > 0 {
return fmt.Errorf("permissions %04o for '%v' are too open; requires the file is NOT writable by the same group and NOT accessible by others; suggest 0640 or 0600", fi.Mode()&0777, entry)
}
// Read and parse.
b, err := io.ReadAll(f)
if err != nil {
return err
}
entrySections, err := config_parser.Parse(string(b))
if err != nil {
return fmt.Errorf("failed to parse config file %v:\n%w", entry, err)
}
m.entryToSectionMap[entry] = m.convertSectionsToMap(entrySections)
return nil
}
func unsqueezeEntries(patternEntries []string) (unsqueezed []string, err error) {
unsqueezed = make([]string, 0, len(patternEntries))
for _, pattern := range patternEntries {
files, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, file := range files {
// We only support .dae
if !strings.HasSuffix(file, ".dae") {
continue
}
fi, err := os.Stat(file)
if err != nil {
return nil, err
}
if fi.IsDir() {
continue
}
unsqueezed = append(unsqueezed, file)
}
}
if len(unsqueezed) == 0 {
unsqueezed = nil
}
return unsqueezed, nil
}
func (m *Merger) dfsMerge(entry string, fatherEntry string) (err error) {
// Read entry and check circular include.
if err = m.readEntry(entry); err != nil {
if errors.Is(err, CircularIncludeError) {
return fmt.Errorf("%w: %v -> %v -> ... -> %v", err, fatherEntry, entry, fatherEntry)
}
return err
}
sectionMap := m.entryToSectionMap[entry]
// Extract childEntries.
includes := sectionMap["include"]
var patterEntries = make([]string, 0, len(includes))
for _, include := range includes {
switch v := include.Value.(type) {
case *config_parser.Param:
nextEntry := v.String(true)
patterEntries = append(patterEntries, filepath.Join(m.entryDir, nextEntry))
default:
return fmt.Errorf("unsupported include grammar in %v: %v", entry, include.String())
}
}
// DFS and merge children recursively.
childEntries, err := unsqueezeEntries(patterEntries)
if err != nil {
return err
}
for _, nextEntry := range childEntries {
if err = m.dfsMerge(nextEntry, entry); err != nil {
return err
}
}
/// Merge into father. Do not need to retrieve sectionMap again because go map is a reference.
if fatherEntry == "" {
// We are already on the top.
return nil
}
fatherSectionMap := m.entryToSectionMap[fatherEntry]
for sec := range sectionMap {
items := m.mergeItems(fatherSectionMap[sec], sectionMap[sec])
fatherSectionMap[sec] = items
}
return nil
}
func (m *Merger) convertSectionsToMap(sections []*config_parser.Section) (sectionMap map[string][]*config_parser.Item) {
sectionMap = make(map[string][]*config_parser.Item)
for _, sec := range sections {
items, ok := sectionMap[sec.Name]
if ok {
sectionMap[sec.Name] = m.mergeItems(items, sec.Items)
} else {
sectionMap[sec.Name] = sec.Items
}
}
return sectionMap
}
func (m *Merger) convertMapToSections(sectionMap map[string][]*config_parser.Item) (sections []*config_parser.Section) {
sections = make([]*config_parser.Section, 0, len(sectionMap))
for name, items := range sectionMap {
sections = append(sections, &config_parser.Section{
Name: name,
Items: items,
})
}
return sections
}
func (m *Merger) mergeItems(to, from []*config_parser.Item) (items []*config_parser.Item) {
items = make([]*config_parser.Item, len(to)+len(from))
copy(items, to)
copy(items[len(to):], from)
return items
}