feat: dns routing (#26)

This commit is contained in:
mzz
2023-02-25 02:38:21 +08:00
committed by GitHub
parent 33ad434f8a
commit 8bd6a77398
48 changed files with 2758 additions and 1449 deletions

203
component/dns/dns.go Normal file
View File

@ -0,0 +1,203 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/config"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"net/url"
"sync"
)
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
type Dns struct {
upstream []*UpstreamResolver
upstream2IndexMu sync.Mutex
upstream2Index map[*Upstream]int
reqMatcher *RequestMatcher
respMatcher *ResponseMatcher
}
type NewOption struct {
UpstreamReadyCallback func(raw *url.URL, upstream *Upstream) (err error)
}
func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error) {
s = &Dns{
upstream2Index: map[*Upstream]int{
nil: int(consts.DnsRequestOutboundIndex_AsIs),
},
}
// Parse upstream.
upstreamName2Id := map[string]uint8{}
for i, upstreamRaw := range dns.Upstream {
if i >= int(consts.DnsRequestOutboundIndex_UserDefinedMax) ||
i >= int(consts.DnsResponseOutboundIndex_UserDefinedMax) {
return nil, fmt.Errorf("too many upstreams")
}
tag, link := common.GetTagFromLinkLikePlaintext(upstreamRaw)
if tag == "" {
return nil, fmt.Errorf("%w: '%v' has no tag", BadUpstreamFormatError, upstreamRaw)
}
u, err := url.Parse(link)
if err != nil {
return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err)
}
r := &UpstreamResolver{
Raw: u,
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
return func(raw *url.URL, upstream *Upstream) (err error) {
if opt != nil && opt.UpstreamReadyCallback != nil {
if err = opt.UpstreamReadyCallback(raw, upstream); err != nil {
return err
}
}
s.upstream2IndexMu.Lock()
s.upstream2Index[upstream] = i
s.upstream2IndexMu.Unlock()
return nil
}
}(i),
}
upstreamName2Id[tag] = uint8(len(s.upstream))
s.upstream = append(s.upstream, r)
}
// Optimize routings.
if dns.Routing.Request.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Request.Rules,
&routing.DatReaderOptimizer{Logger: log},
&routing.MergeAndSortRulesOptimizer{},
&routing.DeduplicateParamsOptimizer{},
); err != nil {
return nil, err
}
if dns.Routing.Response.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Response.Rules,
&routing.DatReaderOptimizer{Logger: log},
&routing.MergeAndSortRulesOptimizer{},
&routing.DeduplicateParamsOptimizer{},
); err != nil {
return nil, err
}
// Parse request routing.
reqMatcherBuilder, err := NewRequestMatcherBuilder(log, dns.Routing.Request.Rules, upstreamName2Id, dns.Routing.Request.Fallback)
if err != nil {
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
}
s.reqMatcher, err = reqMatcherBuilder.Build()
if err != nil {
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
}
// Parse response routing.
respMatcherBuilder, err := NewResponseMatcherBuilder(log, dns.Routing.Response.Rules, upstreamName2Id, dns.Routing.Response.Fallback)
if err != nil {
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
}
s.respMatcher, err = respMatcherBuilder.Build()
if err != nil {
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
}
if len(dns.Upstream) == 0 {
// Immediately ready.
if err = opt.UpstreamReadyCallback(nil, nil); err != nil {
return nil, err
}
}
return s, nil
}
func (s *Dns) RequestSelect(msg *dnsmessage.Message) (upstream *Upstream, err error) {
if msg.Response {
return nil, fmt.Errorf("DNS request expected but DNS response received")
}
// Prepare routing.
var qname string
var qtype dnsmessage.Type
if len(msg.Questions) == 0 {
qname = ""
qtype = 0
} else {
q := msg.Questions[0]
qname = q.Name.String()
qtype = q.Type
}
// Route.
upstreamIndex, err := s.reqMatcher.Match(qname, qtype)
if err != nil {
return nil, err
}
// nil indicates AsIs.
if upstreamIndex == consts.DnsRequestOutboundIndex_AsIs {
return nil, nil
}
if int(upstreamIndex) >= len(s.upstream) {
return nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
}
// Get corresponding upstream.
upstream, err = s.upstream[upstreamIndex].GetUpstream()
if err != nil {
return nil, err
}
return upstream, nil
}
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
if !msg.Response {
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
}
// Prepare routing.
var qname string
var qtype dnsmessage.Type
var ips []netip.Addr
if len(msg.Questions) == 0 {
qname = ""
qtype = 0
} else {
q := msg.Questions[0]
qname = q.Name.String()
qtype = q.Type
for _, ans := range msg.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
ips = append(ips, netip.AddrFrom4(body.A))
case *dnsmessage.AAAAResource:
ips = append(ips, netip.AddrFrom16(body.AAAA))
}
}
}
s.upstream2IndexMu.Lock()
from := s.upstream2Index[fromUpstream]
s.upstream2IndexMu.Unlock()
// Route.
upstreamIndex, err = s.respMatcher.Match(qname, qtype, ips, consts.DnsRequestOutboundIndex(from))
if err != nil {
return 0, nil, err
}
// Get corresponding upstream if upstream is neither 'accept' nor 'reject'.
if !upstreamIndex.IsReserved() {
if int(upstreamIndex) >= len(s.upstream) {
return 0, nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
}
upstream, err = s.upstream[upstreamIndex].GetUpstream()
if err != nil {
return 0, nil, err
}
} else {
// Assign explicitly to let coder know.
upstream = nil
}
return upstreamIndex, upstream, nil
}

View File

@ -0,0 +1,47 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/pkg/config_parser"
"golang.org/x/net/dns/dnsmessage"
"strings"
)
var typeNames = map[string]dnsmessage.Type{
"A": dnsmessage.TypeA,
"NS": dnsmessage.TypeNS,
"CNAME": dnsmessage.TypeCNAME,
"SOA": dnsmessage.TypeSOA,
"PTR": dnsmessage.TypePTR,
"MX": dnsmessage.TypeMX,
"TXT": dnsmessage.TypeTXT,
"AAAA": dnsmessage.TypeAAAA,
"SRV": dnsmessage.TypeSRV,
"OPT": dnsmessage.TypeOPT,
"WKS": dnsmessage.TypeWKS,
"HINFO": dnsmessage.TypeHINFO,
"MINFO": dnsmessage.TypeMINFO,
"AXFR": dnsmessage.TypeAXFR,
"ALL": dnsmessage.TypeALL,
}
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
var types []dnsmessage.Type
for _, v := range paramValueGroup {
t, ok := typeNames[strings.ToUpper(v)]
if !ok {
return fmt.Errorf("unknown DNS request type: %v", v)
}
types = append(types, t)
}
return callback(f, types, overrideOutbound)
}
}

View File

@ -0,0 +1,213 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/component/routing/domain_matcher"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"golang.org/x/net/dns/dnsmessage"
"strconv"
)
type RequestMatcherBuilder struct {
upstreamName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet
fallback *routing.Outbound
rules []requestMatchSet
}
func NewRequestMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *RequestMatcherBuilder, err error) {
b = &RequestMatcherBuilder{upstreamName2Id: upstreamName2Id}
rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
if err = rulesBuilder.Apply(rules); err != nil {
return nil, err
}
if err = b.addFallback(fallback); err != nil {
return nil, err
}
return b, nil
}
func (b *RequestMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsRequestOutboundIndex, err error) {
switch upstream {
case consts.DnsRequestOutboundIndex_AsIs.String():
upstreamId = consts.DnsRequestOutboundIndex_AsIs
case consts.DnsRequestOutboundIndex_LogicalAnd.String():
upstreamId = consts.DnsRequestOutboundIndex_LogicalAnd
case consts.DnsRequestOutboundIndex_LogicalOr.String():
upstreamId = consts.DnsRequestOutboundIndex_LogicalOr
default:
_upstreamId, ok := b.upstreamName2Id[upstream]
if !ok {
return 0, fmt.Errorf("upstream %v not found; please define it in section \"dns.upstream\"", strconv.Quote(upstream))
}
upstreamId = consts.DnsRequestOutboundIndex(_upstreamId)
}
return upstreamId, nil
}
func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
switch consts.RoutingDomainKey(key) {
case consts.RoutingDomainKey_Regex,
consts.RoutingDomainKey_Full,
consts.RoutingDomainKey_Keyword,
consts.RoutingDomainKey_Suffix:
default:
return fmt.Errorf("addQName: unsupported key: %v", key)
}
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
Key: consts.RoutingDomainKey(key),
RuleIndex: len(b.simulatedDomainSet),
Domains: values,
})
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, requestMatchSet{
Type: consts.MatchType_DomainSet,
Not: f.Not,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
upstreamName = upstream.Name
}
upstreamId, err := b.upstreamToId(upstreamName)
if err != nil {
return err
}
b.rules = append(b.rules, requestMatchSet{
Type: consts.MatchType_QType,
Value: uint16(value),
Not: f.Not,
Upstream: uint8(upstreamId),
})
}
return nil
}
func (b *RequestMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
if err != nil {
return err
}
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, requestMatchSet{
Type: consts.MatchType_Fallback,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *RequestMatcherBuilder) Build() (matcher *RequestMatcher, err error) {
var m RequestMatcher
// Build domainMatcher
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet {
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
}
if err = m.domainMatcher.Build(); err != nil {
return nil, err
}
// Write routings.
// Fallback rule MUST be the last.
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
return nil, fmt.Errorf("fallback rule MUST be the last")
}
m.matches = b.rules
return &m, nil
}
type RequestMatcher struct {
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
matches []requestMatchSet
}
type requestMatchSet struct {
Value uint16
Not bool
Type consts.MatchType
Upstream uint8
}
func (m *RequestMatcher) Match(
qName string,
qType dnsmessage.Type,
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
var domainMatchBitmap []uint32
if qName != "" {
domainMatchBitmap = m.domainMatcher.MatchDomainBitmap(qName)
}
goodSubrule := false
badRule := false
for i, match := range m.matches {
if badRule || goodSubrule {
goto beforeNextLoop
}
switch match.Type {
case consts.MatchType_DomainSet:
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
goodSubrule = true
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
goodSubrule = true
}
case consts.MatchType_Fallback:
goodSubrule = true
default:
return 0, fmt.Errorf("unknown match type: %v", match.Type)
}
beforeNextLoop:
upstream := consts.DnsRequestOutboundIndex(match.Upstream)
if upstream != consts.DnsRequestOutboundIndex_LogicalOr {
// This match_set reaches the end of subrule.
// We are now at end of rule, or next match_set belongs to another
// subrule.
if goodSubrule == match.Not {
// This subrule does not hit.
badRule = true
}
// Reset goodSubrule.
goodSubrule = false
}
if upstream&consts.DnsRequestOutboundIndex_LogicalMask !=
consts.DnsRequestOutboundIndex_LogicalMask {
// Tail of a rule (line).
// Decide whether to hit.
if !badRule {
return upstream, nil
}
badRule = false
}
}
return 0, fmt.Errorf("no match set hit")
}

View File

@ -0,0 +1,320 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/component/routing/domain_matcher"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"github.com/v2rayA/dae/pkg/trie"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"strconv"
"strings"
)
var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'})
type ResponseMatcherBuilder struct {
upstreamName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet
ipSet []*trie.Trie
fallback *routing.Outbound
rules []responseMatchSet
}
func NewResponseMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *ResponseMatcherBuilder, err error) {
b = &ResponseMatcherBuilder{upstreamName2Id: upstreamName2Id}
rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
rulesBuilder.RegisterFunctionParser(consts.Function_Upstream, routing.EmptyKeyPlainParserFactory(b.addUpstream))
if err = rulesBuilder.Apply(rules); err != nil {
return nil, err
}
if err = b.addFallback(fallback); err != nil {
return nil, err
}
return b, nil
}
func (b *ResponseMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsResponseOutboundIndex, err error) {
switch upstream {
case consts.DnsResponseOutboundIndex_Accept.String():
upstreamId = consts.DnsResponseOutboundIndex_Accept
case consts.DnsResponseOutboundIndex_Reject.String():
upstreamId = consts.DnsResponseOutboundIndex_Reject
case consts.DnsResponseOutboundIndex_LogicalAnd.String():
upstreamId = consts.DnsResponseOutboundIndex_LogicalAnd
case consts.DnsResponseOutboundIndex_LogicalOr.String():
upstreamId = consts.DnsResponseOutboundIndex_LogicalOr
default:
_upstreamId, ok := b.upstreamName2Id[upstream]
if !ok {
return 0, fmt.Errorf("upstream %v not found; please define it in \"dns.upstream\"", strconv.Quote(upstream))
}
upstreamId = consts.DnsResponseOutboundIndex(_upstreamId)
}
return upstreamId, nil
}
func prefix2bin128(prefix netip.Prefix) (bin128 string) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
bits += 96
}
ip := prefix.Addr().As16()
buf := buffer.NewBuffer(128)
defer buf.Put()
loop:
for i := 0; i < len(ip); i++ {
for j := 0; j < 8; j++ {
if (ip[i]>>j)&1 == 1 {
buf.WriteByte('1')
} else {
buf.WriteByte('0')
}
bits--
if bits == 0 {
break loop
}
}
}
return buf.String()
}
func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.Prefix, upstream *routing.Outbound) (err error) {
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
rule := responseMatchSet{
Value: uint16(len(b.ipSet)),
Type: consts.MatchType_IpSet,
Not: f.Not,
Upstream: uint8(upstreamId),
}
var keys []string
// Convert netip.Prefix -> '0' '1' string
for _, prefix := range cidrs {
keys = append(keys, prefix2bin128(prefix))
}
t, err := trie.NewTrie(keys, ValidCidrChars)
if err != nil {
return err
}
b.ipSet = append(b.ipSet, t)
b.rules = append(b.rules, rule)
return nil
}
func (b *ResponseMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
switch consts.RoutingDomainKey(key) {
case consts.RoutingDomainKey_Regex,
consts.RoutingDomainKey_Full,
consts.RoutingDomainKey_Keyword,
consts.RoutingDomainKey_Suffix:
default:
return fmt.Errorf("addQName: unsupported key: %v", key)
}
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
Key: consts.RoutingDomainKey(key),
RuleIndex: len(b.simulatedDomainSet),
Domains: values,
})
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_DomainSet,
Not: f.Not,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values []string, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
upstreamName = upstream.Name
}
upstreamId, err := b.upstreamToId(upstreamName)
if err != nil {
return err
}
lastUpstreamId, err := b.upstreamToId(value)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_Upstream,
Value: uint16(lastUpstreamId),
Not: f.Not,
Upstream: uint8(upstreamId),
})
}
return nil
}
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
upstreamName = upstream.Name
}
upstreamId, err := b.upstreamToId(upstreamName)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_QType,
Value: uint16(value),
Not: f.Not,
Upstream: uint8(upstreamId),
})
}
return nil
}
func (b *ResponseMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
if err != nil {
return err
}
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_Fallback,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *ResponseMatcherBuilder) Build() (matcher *ResponseMatcher, err error) {
var m ResponseMatcher
// Build domainMatcher.
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet {
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
}
if err = m.domainMatcher.Build(); err != nil {
return nil, err
}
// IpSet.
m.ipSet = b.ipSet
// Write routings.
// Fallback rule MUST be the last.
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
return nil, fmt.Errorf("fallback rule MUST be the last")
}
m.matches = b.rules
return &m, nil
}
type ResponseMatcher struct {
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
ipSet []*trie.Trie
matches []responseMatchSet
}
type responseMatchSet struct {
Value uint16
Not bool
Type consts.MatchType
Upstream uint8
}
func (m *ResponseMatcher) Match(
qName string,
qType dnsmessage.Type,
ips []netip.Addr,
upstream consts.DnsRequestOutboundIndex,
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
if qName == "" {
return 0, fmt.Errorf("qName cannot be empty")
}
qName = strings.TrimSuffix(strings.ToLower(qName), ".")
domainMatchBitmap := m.domainMatcher.MatchDomainBitmap(qName)
bin128 := make([]string, 0, len(ips))
for _, ip := range ips {
bin128 = append(bin128, prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128)))
}
goodSubrule := false
badRule := false
for i, match := range m.matches {
if badRule || goodSubrule {
goto beforeNextLoop
}
switch match.Type {
case consts.MatchType_DomainSet:
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
goodSubrule = true
}
case consts.MatchType_IpSet:
for _, bin128 := range bin128 {
// Check if any of IP hit the rule.
if m.ipSet[match.Value].HasPrefix(bin128) {
goodSubrule = true
break
}
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
goodSubrule = true
}
case consts.MatchType_Upstream:
if upstream == consts.DnsRequestOutboundIndex(match.Value) {
goodSubrule = true
}
case consts.MatchType_Fallback:
goodSubrule = true
default:
return 0, fmt.Errorf("unknown match type: %v", match.Type)
}
beforeNextLoop:
upstream := consts.DnsResponseOutboundIndex(match.Upstream)
if upstream != consts.DnsResponseOutboundIndex_LogicalOr {
// This match_set reaches the end of subrule.
// We are now at end of rule, or next match_set belongs to another
// subrule.
if goodSubrule == match.Not {
// This subrule does not hit.
badRule = true
}
// Reset goodSubrule.
goodSubrule = false
}
if upstream&consts.DnsResponseOutboundIndex_LogicalMask !=
consts.DnsResponseOutboundIndex_LogicalMask {
// Tail of a rule (line).
// Decide whether to hit.
if !badRule {
return upstream, nil
}
badRule = false
}
}
return 0, fmt.Errorf("no match set hit")
}

153
component/dns/upstream.go Normal file
View File

@ -0,0 +1,153 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"context"
"fmt"
"github.com/mzz2017/softwind/protocol/direct"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/common/netutils"
"net"
"net/url"
"strconv"
"sync"
"time"
)
type UpstreamScheme string
const (
UpstreamScheme_TCP UpstreamScheme = "tcp"
UpstreamScheme_UDP UpstreamScheme = "udp"
UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp"
)
func (s UpstreamScheme) ContainsTcp() bool {
switch s {
case UpstreamScheme_TCP,
UpstreamScheme_TCP_UDP:
return true
default:
return false
}
}
func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, err error) {
var __port string
switch scheme = UpstreamScheme(raw.Scheme); scheme {
case UpstreamScheme_TCP, UpstreamScheme_UDP, UpstreamScheme_TCP_UDP:
__port = raw.Port()
if __port == "" {
__port = "53"
}
default:
return "", "", 0, fmt.Errorf("unexpected dns_upstream format")
}
_port, err := strconv.ParseUint(__port, 10, 16)
if err != nil {
return "", "", 0, fmt.Errorf("failed to parse dns_upstream port: %v", err)
}
port = uint16(_port)
hostname = raw.Hostname()
return scheme, hostname, port, nil
}
type Upstream struct {
Scheme UpstreamScheme
Hostname string
Port uint16
*netutils.Ip46
}
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
scheme, hostname, port, err := ParseRawUpstream(upstream)
if err != nil {
return nil, err
}
systemDns, err := netutils.SystemDns()
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = netutils.TryUpdateSystemDns1s()
}
}()
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
}
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
return nil, fmt.Errorf("dns_upstream has no record")
}
return &Upstream{
Scheme: scheme,
Hostname: hostname,
Port: port,
Ip46: ip46,
}, nil
}
func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
if u.Ip4.IsValid() && u.Ip6.IsValid() {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6}
} else {
if u.Ip4.IsValid() {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4}
} else {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_6}
}
}
switch u.Scheme {
case UpstreamScheme_TCP:
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP}
case UpstreamScheme_UDP:
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP}
case UpstreamScheme_TCP_UDP:
// UDP first.
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP, consts.L4ProtoStr_TCP}
}
return ipversions, l4protos
}
func (u *Upstream) String() string {
return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port)))
}
type UpstreamResolver struct {
Raw *url.URL
// FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
mu sync.Mutex
upstream *Upstream
init bool
}
func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
u.mu.Lock()
defer u.mu.Unlock()
if !u.init {
defer func() {
if err == nil {
if err = u.FinishInitCallback(u.Raw, u.upstream); err != nil {
u.upstream = nil
return
}
u.init = true
}
}()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
}
}
return u.upstream, nil
}

View File

@ -90,11 +90,8 @@ func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) {
}
func (a *AliveDialerSet) printLatencies() {
if !a.log.IsLevelEnabled(logrus.TraceLevel) {
return
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("%v (%v):\n", a.dialerGroupName, a.CheckTyp.String()))
builder.WriteString(fmt.Sprintf("Group '%v' [%v]:\n", a.dialerGroupName, a.CheckTyp.String()))
for _, d := range a.inorderedAliveDialerSet {
latency, ok := a.dialerToLatency[d]
if !ok {
@ -210,9 +207,13 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
string(a.selectionPolicy): a.minLatency.latency,
"group": a.dialerGroupName,
"network": a.CheckTyp.String(),
"new dialer": a.minLatency.dialer.Name(),
"old dialer": oldDialerName,
"new_dialer": a.minLatency.dialer.Name(),
"old_dialer": oldDialerName,
}).Infof("Group %vselects dialer", re)
if a.log.IsLevelEnabled(logrus.TraceLevel) {
a.printLatencies()
}
} else {
// Alive -> not alive
defer a.aliveChangeCallback(false)
@ -221,9 +222,6 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
"network": a.CheckTyp.String(),
}).Infof("Group has no dialer alive")
}
if a.log.IsLevelEnabled(logrus.TraceLevel) {
a.printLatencies()
}
}
} else {
if alive && minPolicy && a.minLatency.dialer == nil {

View File

@ -118,7 +118,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
if err != nil {
return nil, err
}
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
if err != nil {
return nil, err
}
@ -153,7 +153,7 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort string) (opt *CheckDns
if err != nil {
return nil, fmt.Errorf("bad port: %v", err)
}
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
if err != nil {
return nil, err
}
@ -409,6 +409,10 @@ func (d *Dialer) NotifyCheck() {
}
}
func (d *Dialer) MustGetLatencies10(typ *NetworkType) *LatenciesN {
return d.mustGetCollection(typ).Latencies10
}
// RegisterAliveDialerSet is thread-safe.
func (d *Dialer) RegisterAliveDialerSet(a *AliveDialerSet) {
if a == nil {

View File

@ -19,7 +19,7 @@ type DialerSelectionPolicy struct {
FixedIndex int
}
func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) {
func NewDialerSelectionPolicyFromGroupParam(param *config.Group) (policy *DialerSelectionPolicy, err error) {
switch val := param.Policy.(type) {
case string:
switch consts.DialerSelectionPolicy(val) {

View File

@ -15,6 +15,8 @@ import (
"strings"
)
var ValidDomainChars = trie.NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^"))
type AhocorasickSlimtrie struct {
validAcIndexes []int
validTrieIndexes []int
@ -173,7 +175,7 @@ func (n *AhocorasickSlimtrie) Build() (err error) {
}
toBuild = ToSuffixTrieStrings(toBuild)
sort.Strings(toBuild)
n.trie[i], err = trie.NewTrie(toBuild)
n.trie[i], err = trie.NewTrie(toBuild, ValidDomainChars)
if err != nil {
return err
}

View File

@ -94,7 +94,6 @@ var TestSample = []string{
}
type RoutingMatcherBuilder struct {
*routing.DefaultMatcherBuilder
outboundName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet
Fallback string
@ -118,7 +117,7 @@ func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string,
consts.RoutingDomainKey_Keyword,
consts.RoutingDomainKey_Suffix:
default:
b.err = fmt.Errorf("AddDomain: unsupported key: %v", key)
b.err = fmt.Errorf("addDomain: unsupported key: %v", key)
return
}
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
@ -132,22 +131,16 @@ func getDomain() (simulatedDomainSet []routing.DomainSet, err error) {
var rules []*config_parser.RoutingRule
sections, err := config_parser.Parse(`
routing {
pname(NetworkManager, dnsmasq, systemd-resolved) -> must_direct # Traffic of DNS in local must be direct to avoid loop when binding to WAN.
pname(sogou-qimpanel, sogou-qimpanel-watchdog) -> block
ip(geoip:private, 224.0.0.0/3, 'ff00::/8') -> direct # Put it in front unless you know what you're doing.
domain(geosite:bing)->us
domain(full:dns.google) && port(53) -> direct
domain(full:dns.google) -> direct
domain(geosite:category-ads-all) -> block
ip(geoip:private) -> direct
ip(geoip:cn) -> direct
domain(geosite:cn) -> direct
fallback: my_group
}`)
if err != nil {
return nil, err
}
var r config.Routing
if err = config.RoutingRuleAndParamParser(reflect.ValueOf(&r), sections[0]); err != nil {
if err = config.SectionParser(reflect.ValueOf(&r), sections[0]); err != nil {
return nil, err
}
if rules, err = routing.ApplyRulesOptimizers(r.Rules,
@ -159,8 +152,13 @@ routing {
return nil, fmt.Errorf("ApplyRulesOptimizers error:\n%w", err)
}
builder := RoutingMatcherBuilder{}
if err = routing.ApplyMatcherBuilder(logrus.StandardLogger(), &builder, rules, r.Fallback); err != nil {
return nil, fmt.Errorf("ApplyMatcherBuilder: %w", err)
rb := routing.NewRulesBuilder(logrus.StandardLogger())
rb.RegisterFunctionParser("domain", func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
return nil
})
if err = rb.Apply(rules); err != nil {
return nil, fmt.Errorf("Apply: %w", err)
}
return builder.simulatedDomainSet, nil
}

View File

@ -0,0 +1,139 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package routing
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/pkg/config_parser"
"net/netip"
"strings"
)
type FunctionParser func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)
// Preset function parser factories.
// PlainParserFactory is for style unity.
func PlainParserFactory(callback func(f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
return callback(f, key, paramValueGroup, overrideOutbound)
}
}
// EmptyKeyPlainParserFactory only accepts function with empty key.
func EmptyKeyPlainParserFactory(callback func(f *config_parser.Function, values []string, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
if key != "" {
return fmt.Errorf("this function cannot accept a key")
}
return callback(f, paramValueGroup, overrideOutbound)
}
}
func IpParserFactory(callback func(f *config_parser.Function, cidrs []netip.Prefix, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
cidrs, err := parsePrefixes(paramValueGroup)
if err != nil {
return err
}
return callback(f, cidrs, overrideOutbound)
}
}
func MacParserFactory(callback func(f *config_parser.Function, macAddrs [][6]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var macAddrs [][6]byte
for _, v := range paramValueGroup {
mac, err := common.ParseMac(v)
if err != nil {
return err
}
macAddrs = append(macAddrs, mac)
}
return callback(f, macAddrs, overrideOutbound)
}
}
func PortRangeParserFactory(callback func(f *config_parser.Function, portRanges [][2]uint16, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var portRanges [][2]uint16
for _, v := range paramValueGroup {
portRange, err := common.ParsePortRange(v)
if err != nil {
return err
}
portRanges = append(portRanges, portRange)
}
return callback(f, portRanges, overrideOutbound)
}
}
func L4ProtoParserFactory(callback func(f *config_parser.Function, l4protoType consts.L4ProtoType, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var l4protoType consts.L4ProtoType
for _, v := range paramValueGroup {
switch v {
case "tcp":
l4protoType |= consts.L4ProtoType_TCP
case "udp":
l4protoType |= consts.L4ProtoType_UDP
}
}
return callback(f, l4protoType, overrideOutbound)
}
}
func IpVersionParserFactory(callback func(f *config_parser.Function, ipVersion consts.IpVersionType, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var ipVersion consts.IpVersionType
for _, v := range paramValueGroup {
switch v {
case "4":
ipVersion |= consts.IpVersion_4
case "6":
ipVersion |= consts.IpVersion_6
}
}
return callback(f, ipVersion, overrideOutbound)
}
}
func ProcessNameParserFactory(callback func(f *config_parser.Function, procNames [][consts.TaskCommLen]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var procNames [][consts.TaskCommLen]byte
for _, v := range paramValueGroup {
if len([]byte(v)) > consts.TaskCommLen {
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
}
procNames = append(procNames, toProcessName(v))
}
return callback(f, procNames, overrideOutbound)
}
}
func parsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
for _, value := range values {
toParse := value
if strings.LastIndexByte(value, '/') == -1 {
toParse += "/32"
}
prefix, err := netip.ParsePrefix(toParse)
if err != nil {
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
}
cidrs = append(cidrs, prefix)
}
return cidrs, nil
}
func toProcessName(processName string) (procName [consts.TaskCommLen]byte) {
n := []byte(processName)
copy(procName[:], n)
return procName
}

View File

@ -8,18 +8,11 @@ package routing
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/pkg/config_parser"
"net/netip"
"strconv"
"strings"
)
var FakeOutbound_MUST_DIRECT = consts.OutboundMustDirect.String()
var FakeOutbound_AND = consts.OutboundLogicalAnd.String()
var FakeOutbound_OR = consts.OutboundLogicalOr.String()
type DomainSet struct {
Key consts.RoutingDomainKey
RuleIndex int
@ -31,22 +24,71 @@ type Outbound struct {
Mark uint32
}
type MatcherBuilder interface {
AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound)
AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound)
AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound)
AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound)
AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound)
AddFallback(outbound *Outbound)
AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound)
AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound)
type RulesBuilder struct {
log *logrus.Logger
parsers map[string]FunctionParser
}
func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
func NewRulesBuilder(log *logrus.Logger) *RulesBuilder {
return &RulesBuilder{
log: log,
parsers: make(map[string]FunctionParser),
}
}
func (b *RulesBuilder) RegisterFunctionParser(funcName string, parser FunctionParser) {
b.parsers[funcName] = parser
}
func (b *RulesBuilder) Apply(rules []*config_parser.RoutingRule) (err error) {
for _, rule := range rules {
b.log.Debugln("[rule]", rule.String(true))
outbound, err := ParseOutbound(&rule.Outbound)
if err != nil {
return err
}
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
for iFunc, f := range rule.AndFunctions {
// f is like: domain(domain:baidu.com)
functionParser, ok := b.parsers[f.Name]
if !ok {
return fmt.Errorf("unknown function: %v", f.Name)
}
paramValueGroups, keyOrder := groupParamValuesByKey(f.Params)
for jMatchSet, key := range keyOrder {
paramValueGroup := paramValueGroups[key]
// Preprocess the outbound.
overrideOutbound := &Outbound{
Name: consts.OutboundLogicalOr.String(),
Mark: outbound.Mark,
}
if jMatchSet == len(keyOrder)-1 {
overrideOutbound.Name = consts.OutboundLogicalAnd.String()
if iFunc == len(rule.AndFunctions)-1 {
overrideOutbound.Name = outbound.Name
}
}
{
// Debug
symNot := ""
if f.Not {
symNot = "!"
}
b.log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound.Name)
}
if err = functionParser(b.log, f, key, paramValueGroup, overrideOutbound); err != nil {
return fmt.Errorf("failed to parse '%v': %w", f.String(false), err)
}
}
}
}
return nil
}
func groupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
groups := make(map[string][]string)
for _, param := range params {
if _, ok := groups[param.Key]; !ok {
@ -57,28 +99,7 @@ func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[strin
return groups, keyOrder
}
func ParsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
for _, value := range values {
toParse := value
if strings.LastIndexByte(value, '/') == -1 {
toParse += "/32"
}
prefix, err := netip.ParsePrefix(toParse)
if err != nil {
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
}
cidrs = append(cidrs, prefix)
}
return cidrs, nil
}
func ToProcessName(processName string) (procName [consts.TaskCommLen]byte) {
n := []byte(processName)
copy(procName[:], n)
return procName
}
func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
func ParseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
outbound = &Outbound{
Name: rawOutbound.Name,
Mark: 0,
@ -98,164 +119,3 @@ func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err
}
return outbound, nil
}
func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*config_parser.RoutingRule, fallbackOutbound interface{}) (err error) {
for _, rule := range rules {
log.Debugln("[rule]", rule.String(true))
outbound, err := parseOutbound(&rule.Outbound)
if err != nil {
return err
}
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
for iFunc, f := range rule.AndFunctions {
// f is like: domain(domain:baidu.com)
paramValueGroups, keyOrder := GroupParamValuesByKey(f.Params)
for jMatchSet, key := range keyOrder {
paramValueGroup := paramValueGroups[key]
// Preprocess the outbound.
overrideOutbound := &Outbound{
Name: FakeOutbound_OR,
Mark: outbound.Mark,
}
if jMatchSet == len(keyOrder)-1 {
overrideOutbound.Name = FakeOutbound_AND
if iFunc == len(rule.AndFunctions)-1 {
overrideOutbound.Name = outbound.Name
}
}
{
// Debug
symNot := ""
if f.Not {
symNot = "!"
}
log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound)
}
builder.AddAnyBefore(f, key, paramValueGroup, overrideOutbound)
switch f.Name {
case consts.Function_Domain:
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
case consts.Function_Ip, consts.Function_SourceIp:
cidrs, err := ParsePrefixes(paramValueGroup)
if err != nil {
return err
}
if f.Name == consts.Function_Ip {
builder.AddIp(f, cidrs, overrideOutbound)
} else {
builder.AddSourceIp(f, cidrs, overrideOutbound)
}
case consts.Function_Mac:
var macAddrs [][6]byte
for _, v := range paramValueGroup {
mac, err := common.ParseMac(v)
if err != nil {
return err
}
macAddrs = append(macAddrs, mac)
}
builder.AddSourceMac(f, macAddrs, overrideOutbound)
case consts.Function_Port, consts.Function_SourcePort:
var portRanges [][2]uint16
for _, v := range paramValueGroup {
portRange, err := common.ParsePortRange(v)
if err != nil {
return err
}
portRanges = append(portRanges, portRange)
}
if f.Name == consts.Function_Port {
builder.AddPort(f, portRanges, overrideOutbound)
} else {
builder.AddSourcePort(f, portRanges, overrideOutbound)
}
case consts.Function_L4Proto:
var l4protoType consts.L4ProtoType
for _, v := range paramValueGroup {
switch v {
case "tcp":
l4protoType |= consts.L4ProtoType_TCP
case "udp":
l4protoType |= consts.L4ProtoType_UDP
}
}
builder.AddL4Proto(f, l4protoType, overrideOutbound)
case consts.Function_IpVersion:
var ipVersion consts.IpVersionType
for _, v := range paramValueGroup {
switch v {
case "4":
ipVersion |= consts.IpVersion_4
case "6":
ipVersion |= consts.IpVersion_6
}
}
builder.AddIpVersion(f, ipVersion, overrideOutbound)
case consts.Function_ProcessName:
var procNames [][consts.TaskCommLen]byte
for _, v := range paramValueGroup {
if len([]byte(v)) > consts.TaskCommLen {
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
}
procNames = append(procNames, ToProcessName(v))
}
builder.AddProcessName(f, procNames, overrideOutbound)
default:
return fmt.Errorf("unsupported function name: %v", f.Name)
}
builder.AddAnyAfter(f, key, paramValueGroup, overrideOutbound)
}
}
}
var rawFallback *config_parser.Function
switch fallback := fallbackOutbound.(type) {
case string:
rawFallback = &config_parser.Function{Name: fallback}
case *config_parser.Function:
rawFallback = fallback
default:
return fmt.Errorf("unknown type of 'fallback' in section routing: %T", fallback)
}
fallback, err := parseOutbound(rawFallback)
if err != nil {
return err
}
builder.AddAnyBefore(&config_parser.Function{
Name: "fallback",
}, "", nil, fallback)
builder.AddFallback(fallback)
builder.AddAnyAfter(&config_parser.Function{
Name: "fallback",
}, "", nil, fallback)
return nil
}
type DefaultMatcherBuilder struct {
}
func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddFallback(outbound *Outbound) {}
func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound) {
}

View File

@ -157,7 +157,7 @@ func (r *CryptoFrameRelocation) BytesFromPool() []byte {
return pool.Get(0)
}
right := r.o[len(r.o)-1]
return r.copyBytes(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
return r.copyBytesToPool(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
}
// RangeFromPool copy bytes from iUpperAppOffset to jUpperAppOffset.
@ -191,11 +191,11 @@ func (r *CryptoFrameRelocation) RangeFromPool(i, j int) []byte {
}
}
return r.copyBytes(iOuter, iInner, jOuter, jInner, j-i+1)
return r.copyBytesToPool(iOuter, iInner, jOuter, jInner, j-i+1)
}
// copyBytes copy bytes including i and j.
func (r *CryptoFrameRelocation) copyBytes(iOuter, iInner, jOuter, jInner, size int) []byte {
// copyBytesToPool copy bytes including i and j.
func (r *CryptoFrameRelocation) copyBytesToPool(iOuter, iInner, jOuter, jInner, size int) []byte {
b := pool.Get(size)
//io := r.o[iOuter]
k := 0

View File

@ -8,6 +8,7 @@ package sniffing
import (
"github.com/mzz2017/softwind/pool"
"io"
"sync"
)
type Sniffer struct {
@ -15,12 +16,13 @@ type Sniffer struct {
buf []byte
bufAt int
stream bool
readMu sync.Mutex
}
func NewStreamSniffer(r io.Reader, bufSize int) *Sniffer {
s := &Sniffer{
r: r,
buf: pool.Get(bufSize),
buf: make([]byte, bufSize),
stream: true,
}
return s
@ -37,6 +39,8 @@ func NewPacketSniffer(data []byte) *Sniffer {
type sniff func() (d string, err error)
func (s *Sniffer) SniffTcp() (d string, err error) {
s.readMu.Lock()
defer s.readMu.Unlock()
if s.stream {
n, err := s.r.Read(s.buf)
if err != nil {
@ -65,6 +69,8 @@ func (s *Sniffer) SniffTcp() (d string, err error) {
}
func (s *Sniffer) Read(p []byte) (n int, err error) {
s.readMu.Lock()
defer s.readMu.Unlock()
if s.buf != nil && s.bufAt < len(s.buf) {
// Read buf first.
n = copy(p, s.buf[s.bufAt:])
@ -84,6 +90,5 @@ func (s *Sniffer) Read(p []byte) (n int, err error) {
}
func (s *Sniffer) Close() (err error) {
// DO NOT use pool.Put() here because Close() may not interrupt the reading, which will modify the value of the pool buffer.
return nil
}