optimize: give warning instead of error when invalid domain is given in routing

This commit is contained in:
mzz2017
2023-03-25 00:35:45 +08:00
parent 2d75511c7e
commit 07ff753cf3
5 changed files with 42 additions and 28 deletions

View File

@ -7,17 +7,18 @@ package dns
import ( import (
"fmt" "fmt"
"github.com/sirupsen/logrus"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/component/routing/domain_matcher" "github.com/daeuniverse/dae/component/routing/domain_matcher"
"github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"strconv" "strconv"
) )
type RequestMatcherBuilder struct { type RequestMatcherBuilder struct {
log *logrus.Logger
upstreamName2Id map[string]uint8 upstreamName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet simulatedDomainSet []routing.DomainSet
fallback *routing.Outbound fallback *routing.Outbound
@ -25,7 +26,7 @@ type RequestMatcherBuilder struct {
} }
func NewRequestMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *RequestMatcherBuilder, err error) { func NewRequestMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *RequestMatcherBuilder, err error) {
b = &RequestMatcherBuilder{upstreamName2Id: upstreamName2Id} b = &RequestMatcherBuilder{log: log, upstreamName2Id: upstreamName2Id}
rulesBuilder := routing.NewRulesBuilder(log) rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName)) rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType)) rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
@ -123,7 +124,7 @@ func (b *RequestMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrSt
func (b *RequestMatcherBuilder) Build() (matcher *RequestMatcher, err error) { func (b *RequestMatcherBuilder) Build() (matcher *RequestMatcher, err error) {
var m RequestMatcher var m RequestMatcher
// Build domainMatcher // Build domainMatcher
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen) m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(b.log, consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet { for _, domains := range b.simulatedDomainSet {
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key) m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
} }

View File

@ -7,14 +7,14 @@ package dns
import ( import (
"fmt" "fmt"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/component/routing/domain_matcher" "github.com/daeuniverse/dae/component/routing/domain_matcher"
"github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
"github.com/daeuniverse/dae/pkg/trie" "github.com/daeuniverse/dae/pkg/trie"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"net/netip" "net/netip"
"strconv" "strconv"
@ -24,6 +24,7 @@ import (
var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'}) var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'})
type ResponseMatcherBuilder struct { type ResponseMatcherBuilder struct {
log *logrus.Logger
upstreamName2Id map[string]uint8 upstreamName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet simulatedDomainSet []routing.DomainSet
ipSet []*trie.Trie ipSet []*trie.Trie
@ -32,7 +33,7 @@ type ResponseMatcherBuilder struct {
} }
func NewResponseMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *ResponseMatcherBuilder, err error) { func NewResponseMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *ResponseMatcherBuilder, err error) {
b = &ResponseMatcherBuilder{upstreamName2Id: upstreamName2Id} b = &ResponseMatcherBuilder{log: log, upstreamName2Id: upstreamName2Id}
rulesBuilder := routing.NewRulesBuilder(log) rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName)) rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType)) rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
@ -208,7 +209,7 @@ func (b *ResponseMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrS
func (b *ResponseMatcherBuilder) Build() (matcher *ResponseMatcher, err error) { func (b *ResponseMatcherBuilder) Build() (matcher *ResponseMatcher, err error) {
var m ResponseMatcher var m ResponseMatcher
// Build domainMatcher. // Build domainMatcher.
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen) m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(b.log, consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet { for _, domains := range b.simulatedDomainSet {
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key) m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
} }

View File

@ -7,9 +7,10 @@ package domain_matcher
import ( import (
"fmt" "fmt"
"github.com/v2rayA/ahocorasick-domain"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/pkg/trie" "github.com/daeuniverse/dae/pkg/trie"
"github.com/sirupsen/logrus"
"github.com/v2rayA/ahocorasick-domain"
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
@ -18,6 +19,8 @@ import (
var ValidDomainChars = trie.NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^")) var ValidDomainChars = trie.NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^"))
type AhocorasickSlimtrie struct { type AhocorasickSlimtrie struct {
log *logrus.Logger
validAcIndexes []int validAcIndexes []int
validTrieIndexes []int validTrieIndexes []int
validRegexpIndexes []int validRegexpIndexes []int
@ -30,8 +33,9 @@ type AhocorasickSlimtrie struct {
err error err error
} }
func NewAhocorasickSlimtrie(bitLength int) *AhocorasickSlimtrie { func NewAhocorasickSlimtrie(log *logrus.Logger, bitLength int) *AhocorasickSlimtrie {
return &AhocorasickSlimtrie{ return &AhocorasickSlimtrie{
log: log,
ac: make([]*ahocorasick.Matcher, bitLength), ac: make([]*ahocorasick.Matcher, bitLength),
trie: make([]*trie.Trie, bitLength), trie: make([]*trie.Trie, bitLength),
regexp: make([][]*regexp.Regexp, bitLength), regexp: make([][]*regexp.Regexp, bitLength),
@ -43,13 +47,24 @@ func (n *AhocorasickSlimtrie) AddSet(bitIndex int, patterns []string, typ consts
if n.err != nil { if n.err != nil {
return return
} }
switch typ { nextPattern:
case consts.RoutingDomainKey_Full: for _, d := range patterns {
for _, d := range patterns { switch typ {
case consts.RoutingDomainKey_Full:
for _, r := range []byte(d) {
if !ValidDomainChars.IsValidChar(r) {
n.log.Warnf("DomainMatcher: skip bad full domain: %v: unexpected chat: %v", d, r)
continue nextPattern
}
}
n.toBuildTrie[bitIndex] = append(n.toBuildTrie[bitIndex], "^"+d+"$") n.toBuildTrie[bitIndex] = append(n.toBuildTrie[bitIndex], "^"+d+"$")
} case consts.RoutingDomainKey_Suffix:
case consts.RoutingDomainKey_Suffix: for _, r := range []byte(d) {
for _, d := range patterns { if !ValidDomainChars.IsValidChar(r) {
n.log.Warnf("DomainMatcher: skip bad suffix domain: %v: unexpected chat: %v", d, r)
continue nextPattern
}
}
if strings.HasPrefix(d, ".") { if strings.HasPrefix(d, ".") {
// abc.example.com // abc.example.com
n.toBuildTrie[bitIndex] = append(n.toBuildTrie[bitIndex], d+"$") n.toBuildTrie[bitIndex] = append(n.toBuildTrie[bitIndex], d+"$")
@ -61,24 +76,20 @@ func (n *AhocorasickSlimtrie) AddSet(bitIndex int, patterns []string, typ consts
n.toBuildTrie[bitIndex] = append(n.toBuildTrie[bitIndex], "^"+d+"$") n.toBuildTrie[bitIndex] = append(n.toBuildTrie[bitIndex], "^"+d+"$")
// cannot match abcexample.com // cannot match abcexample.com
} }
} case consts.RoutingDomainKey_Keyword:
case consts.RoutingDomainKey_Keyword: // Only use ac automaton for "keyword" matching to save memory.
// Only use ac automaton for "keyword" matching to save memory.
for _, d := range patterns {
n.toBuildAc[bitIndex] = append(n.toBuildAc[bitIndex], []byte(d)) n.toBuildAc[bitIndex] = append(n.toBuildAc[bitIndex], []byte(d))
} case consts.RoutingDomainKey_Regex:
case consts.RoutingDomainKey_Regex:
for _, d := range patterns {
r, err := regexp.Compile(d) r, err := regexp.Compile(d)
if err != nil { if err != nil {
n.err = fmt.Errorf("failed to compile regex: %v", d) n.err = fmt.Errorf("failed to compile regex: %v", d)
return return
} }
n.regexp[bitIndex] = append(n.regexp[bitIndex], r) n.regexp[bitIndex] = append(n.regexp[bitIndex], r)
default:
n.err = fmt.Errorf("unknown RoutingDomainKey: %v", typ)
return
} }
default:
n.err = fmt.Errorf("unknown RoutingDomainKey: %v", typ)
return
} }
} }
func (n *AhocorasickSlimtrie) MatchDomainBitmap(domain string) (bitmap []uint32) { func (n *AhocorasickSlimtrie) MatchDomainBitmap(domain string) (bitmap []uint32) {

View File

@ -317,7 +317,7 @@ func NewControlPlane(
} }
routingMatcher, err := builder.BuildUserspace(core.bpf.LpmArrayMap) routingMatcher, err := builder.BuildUserspace(core.bpf.LpmArrayMap)
if err != nil { if err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err) return nil, fmt.Errorf("RoutingMatcherBuilder.BuildUserspace: %w", err)
} }
/// Dial mode. /// Dial mode.

View File

@ -21,6 +21,7 @@ import (
) )
type RoutingMatcherBuilder struct { type RoutingMatcherBuilder struct {
log *logrus.Logger
outboundName2Id map[string]uint8 outboundName2Id map[string]uint8
bpf *bpfObjects bpf *bpfObjects
rules []bpfMatchSet rules []bpfMatchSet
@ -30,7 +31,7 @@ type RoutingMatcherBuilder struct {
} }
func NewRoutingMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, outboundName2Id map[string]uint8, bpf *bpfObjects, fallback config.FunctionOrString) (b *RoutingMatcherBuilder, err error) { func NewRoutingMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, outboundName2Id map[string]uint8, bpf *bpfObjects, fallback config.FunctionOrString) (b *RoutingMatcherBuilder, err error) {
b = &RoutingMatcherBuilder{outboundName2Id: outboundName2Id, bpf: bpf} b = &RoutingMatcherBuilder{log: log, outboundName2Id: outboundName2Id, bpf: bpf}
rulesBuilder := routing.NewRulesBuilder(log) rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_Domain, routing.PlainParserFactory(b.addDomain)) rulesBuilder.RegisterFunctionParser(consts.Function_Domain, routing.PlainParserFactory(b.addDomain))
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp)) rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
@ -318,7 +319,7 @@ func (b *RoutingMatcherBuilder) BuildKernspace(log *logrus.Logger) (err error) {
func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *RoutingMatcher, err error) { func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *RoutingMatcher, err error) {
// Build domainMatcher // Build domainMatcher
domainMatcher := domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen) domainMatcher := domain_matcher.NewAhocorasickSlimtrie(b.log, consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet { for _, domains := range b.simulatedDomainSet {
domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key) domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
} }