dae/config/patch.go
2024-01-04 17:28:16 +08:00

65 lines
1.8 KiB
Go

/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2022-2024, daeuniverse Organization <dae@v2raya.org>
*/
package config
import (
"github.com/daeuniverse/dae/common"
"github.com/sirupsen/logrus"
"strings"
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/pkg/config_parser"
)
// patch is the signature shared by the config fix-up functions in this file.
// A patch receives the Config by pointer, so it can modify it, and returns an
// error.
type patch func(params *Config) error
// patches lists the patch functions defined in this file.
// NOTE(review): presumably these are run in slice order against a freshly
// parsed Config; the caller is not visible here — confirm at the call site.
var patches = []patch{
patchTcpCheckHttpMethod,
patchEmptyDns,
patchMustOutbound,
}
// patchTcpCheckHttpMethod makes sure Global.TcpCheckHttpMethod holds a method
// accepted by common.IsValidHttpMethod. An unrecognized value is reported with
// a warning and replaced by "CONNECT". It always returns nil.
func patchTcpCheckHttpMethod(params *Config) error {
	method := params.Global.TcpCheckHttpMethod
	if common.IsValidHttpMethod(method) {
		// Nothing to fix.
		return nil
	}

	logrus.Warnf("Unknown HTTP Method '%v'. Fallback to 'CONNECT'.", method)
	params.Global.TcpCheckHttpMethod = "CONNECT"
	return nil
}
// patchEmptyDns fills in the DNS routing fallbacks that were left unset (nil):
//   - the request fallback becomes the string form of
//     consts.DnsRequestOutboundIndex_AsIs;
//   - the response fallback becomes the string form of
//     consts.DnsResponseOutboundIndex_Accept.
//
// Fallbacks that are already set are not touched. It always returns nil.
func patchEmptyDns(params *Config) error {
	requestFallback := &params.Dns.Routing.Request.Fallback
	responseFallback := &params.Dns.Routing.Response.Fallback

	if *requestFallback == nil {
		*requestFallback = consts.DnsRequestOutboundIndex_AsIs.String()
	}
	if *responseFallback == nil {
		*responseFallback = consts.DnsResponseOutboundIndex_Accept.String()
	}
	return nil
}
// patchMustOutbound rewrites outbounds whose name carries the "must_" prefix:
// the prefix is stripped from the name and a parameter with Val "must" is
// appended to the outbound's Params.
//
// The rewrite is applied to the outbound of every routing rule, except that
// the exact name "must_rules" is left untouched there, and to the routing
// fallback. When the fallback is rewritten, Routing.Fallback is replaced with
// the function value returned by FunctionOrStringToFunction. It always
// returns nil.
func patchMustOutbound(params *Config) error {
	const (
		mustPrefix       = "must_"
		reservedOutbound = "must_rules"
	)

	for i := range params.Routing.Rules {
		name := params.Routing.Rules[i].Outbound.Name
		// Skip outbounds without the prefix, and keep the reserved name as is.
		if !strings.HasPrefix(name, mustPrefix) || name == reservedOutbound {
			continue
		}
		params.Routing.Rules[i].Outbound.Name = strings.TrimPrefix(name, mustPrefix)
		params.Routing.Rules[i].Outbound.Params = append(
			params.Routing.Rules[i].Outbound.Params,
			&config_parser.Param{Val: "must"},
		)
	}

	// NOTE(review): unlike the rules loop, the fallback path has no
	// "must_rules" exception — confirm whether that is intentional.
	fallback := FunctionOrStringToFunction(params.Routing.Fallback)
	if !strings.HasPrefix(fallback.Name, mustPrefix) {
		return nil
	}
	fallback.Name = strings.TrimPrefix(fallback.Name, mustPrefix)
	fallback.Params = append(fallback.Params, &config_parser.Param{Val: "must"})
	params.Routing.Fallback = fallback
	return nil
}