mirror of
https://github.com/daeuniverse/dae.git
synced 2025-07-04 15:27:55 +07:00
feat: dns routing (#26)
This commit is contained in:
203
component/dns/dns.go
Normal file
203
component/dns/dns.go
Normal file
@ -0,0 +1,203 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/config"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
|
||||
|
||||
type Dns struct {
|
||||
upstream []*UpstreamResolver
|
||||
upstream2IndexMu sync.Mutex
|
||||
upstream2Index map[*Upstream]int
|
||||
reqMatcher *RequestMatcher
|
||||
respMatcher *ResponseMatcher
|
||||
}
|
||||
|
||||
type NewOption struct {
|
||||
UpstreamReadyCallback func(raw *url.URL, upstream *Upstream) (err error)
|
||||
}
|
||||
|
||||
func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error) {
|
||||
s = &Dns{
|
||||
upstream2Index: map[*Upstream]int{
|
||||
nil: int(consts.DnsRequestOutboundIndex_AsIs),
|
||||
},
|
||||
}
|
||||
// Parse upstream.
|
||||
upstreamName2Id := map[string]uint8{}
|
||||
for i, upstreamRaw := range dns.Upstream {
|
||||
if i >= int(consts.DnsRequestOutboundIndex_UserDefinedMax) ||
|
||||
i >= int(consts.DnsResponseOutboundIndex_UserDefinedMax) {
|
||||
return nil, fmt.Errorf("too many upstreams")
|
||||
}
|
||||
|
||||
tag, link := common.GetTagFromLinkLikePlaintext(upstreamRaw)
|
||||
if tag == "" {
|
||||
return nil, fmt.Errorf("%w: '%v' has no tag", BadUpstreamFormatError, upstreamRaw)
|
||||
}
|
||||
u, err := url.Parse(link)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err)
|
||||
}
|
||||
r := &UpstreamResolver{
|
||||
Raw: u,
|
||||
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
|
||||
return func(raw *url.URL, upstream *Upstream) (err error) {
|
||||
if opt != nil && opt.UpstreamReadyCallback != nil {
|
||||
if err = opt.UpstreamReadyCallback(raw, upstream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.upstream2IndexMu.Lock()
|
||||
s.upstream2Index[upstream] = i
|
||||
s.upstream2IndexMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
}(i),
|
||||
}
|
||||
upstreamName2Id[tag] = uint8(len(s.upstream))
|
||||
s.upstream = append(s.upstream, r)
|
||||
}
|
||||
// Optimize routings.
|
||||
if dns.Routing.Request.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Request.Rules,
|
||||
&routing.DatReaderOptimizer{Logger: log},
|
||||
&routing.MergeAndSortRulesOptimizer{},
|
||||
&routing.DeduplicateParamsOptimizer{},
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dns.Routing.Response.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Response.Rules,
|
||||
&routing.DatReaderOptimizer{Logger: log},
|
||||
&routing.MergeAndSortRulesOptimizer{},
|
||||
&routing.DeduplicateParamsOptimizer{},
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Parse request routing.
|
||||
reqMatcherBuilder, err := NewRequestMatcherBuilder(log, dns.Routing.Request.Rules, upstreamName2Id, dns.Routing.Request.Fallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
|
||||
}
|
||||
s.reqMatcher, err = reqMatcherBuilder.Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
|
||||
}
|
||||
// Parse response routing.
|
||||
respMatcherBuilder, err := NewResponseMatcherBuilder(log, dns.Routing.Response.Rules, upstreamName2Id, dns.Routing.Response.Fallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
|
||||
}
|
||||
s.respMatcher, err = respMatcherBuilder.Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
|
||||
}
|
||||
if len(dns.Upstream) == 0 {
|
||||
// Immediately ready.
|
||||
if err = opt.UpstreamReadyCallback(nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Dns) RequestSelect(msg *dnsmessage.Message) (upstream *Upstream, err error) {
|
||||
if msg.Response {
|
||||
return nil, fmt.Errorf("DNS request expected but DNS response received")
|
||||
}
|
||||
|
||||
// Prepare routing.
|
||||
var qname string
|
||||
var qtype dnsmessage.Type
|
||||
if len(msg.Questions) == 0 {
|
||||
qname = ""
|
||||
qtype = 0
|
||||
} else {
|
||||
q := msg.Questions[0]
|
||||
qname = q.Name.String()
|
||||
qtype = q.Type
|
||||
}
|
||||
// Route.
|
||||
upstreamIndex, err := s.reqMatcher.Match(qname, qtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// nil indicates AsIs.
|
||||
if upstreamIndex == consts.DnsRequestOutboundIndex_AsIs {
|
||||
return nil, nil
|
||||
}
|
||||
if int(upstreamIndex) >= len(s.upstream) {
|
||||
return nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
|
||||
}
|
||||
// Get corresponding upstream.
|
||||
upstream, err = s.upstream[upstreamIndex].GetUpstream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upstream, nil
|
||||
}
|
||||
|
||||
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
|
||||
if !msg.Response {
|
||||
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
|
||||
}
|
||||
|
||||
// Prepare routing.
|
||||
var qname string
|
||||
var qtype dnsmessage.Type
|
||||
var ips []netip.Addr
|
||||
if len(msg.Questions) == 0 {
|
||||
qname = ""
|
||||
qtype = 0
|
||||
} else {
|
||||
q := msg.Questions[0]
|
||||
qname = q.Name.String()
|
||||
qtype = q.Type
|
||||
for _, ans := range msg.Answers {
|
||||
switch body := ans.Body.(type) {
|
||||
case *dnsmessage.AResource:
|
||||
ips = append(ips, netip.AddrFrom4(body.A))
|
||||
case *dnsmessage.AAAAResource:
|
||||
ips = append(ips, netip.AddrFrom16(body.AAAA))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.upstream2IndexMu.Lock()
|
||||
from := s.upstream2Index[fromUpstream]
|
||||
s.upstream2IndexMu.Unlock()
|
||||
// Route.
|
||||
upstreamIndex, err = s.respMatcher.Match(qname, qtype, ips, consts.DnsRequestOutboundIndex(from))
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
// Get corresponding upstream if upstream is neither 'accept' nor 'reject'.
|
||||
if !upstreamIndex.IsReserved() {
|
||||
if int(upstreamIndex) >= len(s.upstream) {
|
||||
return 0, nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
|
||||
}
|
||||
upstream, err = s.upstream[upstreamIndex].GetUpstream()
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
} else {
|
||||
// Assign explicitly to let coder know.
|
||||
upstream = nil
|
||||
}
|
||||
return upstreamIndex, upstream, nil
|
||||
}
|
47
component/dns/function_parser.go
Normal file
47
component/dns/function_parser.go
Normal file
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var typeNames = map[string]dnsmessage.Type{
|
||||
"A": dnsmessage.TypeA,
|
||||
"NS": dnsmessage.TypeNS,
|
||||
"CNAME": dnsmessage.TypeCNAME,
|
||||
"SOA": dnsmessage.TypeSOA,
|
||||
"PTR": dnsmessage.TypePTR,
|
||||
"MX": dnsmessage.TypeMX,
|
||||
"TXT": dnsmessage.TypeTXT,
|
||||
"AAAA": dnsmessage.TypeAAAA,
|
||||
"SRV": dnsmessage.TypeSRV,
|
||||
"OPT": dnsmessage.TypeOPT,
|
||||
"WKS": dnsmessage.TypeWKS,
|
||||
"HINFO": dnsmessage.TypeHINFO,
|
||||
"MINFO": dnsmessage.TypeMINFO,
|
||||
"AXFR": dnsmessage.TypeAXFR,
|
||||
"ALL": dnsmessage.TypeALL,
|
||||
}
|
||||
|
||||
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
|
||||
var types []dnsmessage.Type
|
||||
for _, v := range paramValueGroup {
|
||||
t, ok := typeNames[strings.ToUpper(v)]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown DNS request type: %v", v)
|
||||
}
|
||||
types = append(types, t)
|
||||
}
|
||||
return callback(f, types, overrideOutbound)
|
||||
}
|
||||
}
|
213
component/dns/request_routing.go
Normal file
213
component/dns/request_routing.go
Normal file
@ -0,0 +1,213 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/component/routing/domain_matcher"
|
||||
"github.com/v2rayA/dae/config"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type RequestMatcherBuilder struct {
|
||||
upstreamName2Id map[string]uint8
|
||||
simulatedDomainSet []routing.DomainSet
|
||||
fallback *routing.Outbound
|
||||
rules []requestMatchSet
|
||||
}
|
||||
|
||||
func NewRequestMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *RequestMatcherBuilder, err error) {
|
||||
b = &RequestMatcherBuilder{upstreamName2Id: upstreamName2Id}
|
||||
rulesBuilder := routing.NewRulesBuilder(log)
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
|
||||
if err = rulesBuilder.Apply(rules); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = b.addFallback(fallback); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsRequestOutboundIndex, err error) {
|
||||
switch upstream {
|
||||
case consts.DnsRequestOutboundIndex_AsIs.String():
|
||||
upstreamId = consts.DnsRequestOutboundIndex_AsIs
|
||||
case consts.DnsRequestOutboundIndex_LogicalAnd.String():
|
||||
upstreamId = consts.DnsRequestOutboundIndex_LogicalAnd
|
||||
case consts.DnsRequestOutboundIndex_LogicalOr.String():
|
||||
upstreamId = consts.DnsRequestOutboundIndex_LogicalOr
|
||||
default:
|
||||
_upstreamId, ok := b.upstreamName2Id[upstream]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("upstream %v not found; please define it in section \"dns.upstream\"", strconv.Quote(upstream))
|
||||
}
|
||||
upstreamId = consts.DnsRequestOutboundIndex(_upstreamId)
|
||||
}
|
||||
return upstreamId, nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
|
||||
switch consts.RoutingDomainKey(key) {
|
||||
case consts.RoutingDomainKey_Regex,
|
||||
consts.RoutingDomainKey_Full,
|
||||
consts.RoutingDomainKey_Keyword,
|
||||
consts.RoutingDomainKey_Suffix:
|
||||
default:
|
||||
return fmt.Errorf("addQName: unsupported key: %v", key)
|
||||
}
|
||||
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
||||
Key: consts.RoutingDomainKey(key),
|
||||
RuleIndex: len(b.simulatedDomainSet),
|
||||
Domains: values,
|
||||
})
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, requestMatchSet{
|
||||
Type: consts.MatchType_DomainSet,
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
upstreamName = upstream.Name
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstreamName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, requestMatchSet{
|
||||
Type: consts.MatchType_QType,
|
||||
Value: uint16(value),
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
|
||||
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, requestMatchSet{
|
||||
Type: consts.MatchType_Fallback,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) Build() (matcher *RequestMatcher, err error) {
|
||||
var m RequestMatcher
|
||||
// Build domainMatcher
|
||||
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
|
||||
for _, domains := range b.simulatedDomainSet {
|
||||
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
|
||||
}
|
||||
if err = m.domainMatcher.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write routings.
|
||||
// Fallback rule MUST be the last.
|
||||
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
|
||||
return nil, fmt.Errorf("fallback rule MUST be the last")
|
||||
}
|
||||
m.matches = b.rules
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type RequestMatcher struct {
|
||||
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
|
||||
|
||||
matches []requestMatchSet
|
||||
}
|
||||
|
||||
type requestMatchSet struct {
|
||||
Value uint16
|
||||
Not bool
|
||||
Type consts.MatchType
|
||||
Upstream uint8
|
||||
}
|
||||
|
||||
func (m *RequestMatcher) Match(
|
||||
qName string,
|
||||
qType dnsmessage.Type,
|
||||
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
|
||||
var domainMatchBitmap []uint32
|
||||
if qName != "" {
|
||||
domainMatchBitmap = m.domainMatcher.MatchDomainBitmap(qName)
|
||||
}
|
||||
|
||||
goodSubrule := false
|
||||
badRule := false
|
||||
for i, match := range m.matches {
|
||||
if badRule || goodSubrule {
|
||||
goto beforeNextLoop
|
||||
}
|
||||
switch match.Type {
|
||||
case consts.MatchType_DomainSet:
|
||||
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_QType:
|
||||
if qType == dnsmessage.Type(match.Value) {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Fallback:
|
||||
goodSubrule = true
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown match type: %v", match.Type)
|
||||
}
|
||||
beforeNextLoop:
|
||||
upstream := consts.DnsRequestOutboundIndex(match.Upstream)
|
||||
if upstream != consts.DnsRequestOutboundIndex_LogicalOr {
|
||||
// This match_set reaches the end of subrule.
|
||||
// We are now at end of rule, or next match_set belongs to another
|
||||
// subrule.
|
||||
|
||||
if goodSubrule == match.Not {
|
||||
// This subrule does not hit.
|
||||
badRule = true
|
||||
}
|
||||
|
||||
// Reset goodSubrule.
|
||||
goodSubrule = false
|
||||
}
|
||||
|
||||
if upstream&consts.DnsRequestOutboundIndex_LogicalMask !=
|
||||
consts.DnsRequestOutboundIndex_LogicalMask {
|
||||
// Tail of a rule (line).
|
||||
// Decide whether to hit.
|
||||
if !badRule {
|
||||
return upstream, nil
|
||||
}
|
||||
badRule = false
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("no match set hit")
|
||||
}
|
320
component/dns/response_routing.go
Normal file
320
component/dns/response_routing.go
Normal file
@ -0,0 +1,320 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/component/routing/domain_matcher"
|
||||
"github.com/v2rayA/dae/config"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"github.com/v2rayA/dae/pkg/trie"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'})
|
||||
|
||||
type ResponseMatcherBuilder struct {
|
||||
upstreamName2Id map[string]uint8
|
||||
simulatedDomainSet []routing.DomainSet
|
||||
ipSet []*trie.Trie
|
||||
fallback *routing.Outbound
|
||||
rules []responseMatchSet
|
||||
}
|
||||
|
||||
func NewResponseMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *ResponseMatcherBuilder, err error) {
|
||||
b = &ResponseMatcherBuilder{upstreamName2Id: upstreamName2Id}
|
||||
rulesBuilder := routing.NewRulesBuilder(log)
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Upstream, routing.EmptyKeyPlainParserFactory(b.addUpstream))
|
||||
if err = rulesBuilder.Apply(rules); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = b.addFallback(fallback); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsResponseOutboundIndex, err error) {
|
||||
switch upstream {
|
||||
case consts.DnsResponseOutboundIndex_Accept.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_Accept
|
||||
case consts.DnsResponseOutboundIndex_Reject.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_Reject
|
||||
case consts.DnsResponseOutboundIndex_LogicalAnd.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_LogicalAnd
|
||||
case consts.DnsResponseOutboundIndex_LogicalOr.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_LogicalOr
|
||||
default:
|
||||
_upstreamId, ok := b.upstreamName2Id[upstream]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("upstream %v not found; please define it in \"dns.upstream\"", strconv.Quote(upstream))
|
||||
}
|
||||
upstreamId = consts.DnsResponseOutboundIndex(_upstreamId)
|
||||
}
|
||||
return upstreamId, nil
|
||||
}
|
||||
|
||||
func prefix2bin128(prefix netip.Prefix) (bin128 string) {
|
||||
bits := prefix.Bits()
|
||||
if prefix.Addr().Is4() {
|
||||
bits += 96
|
||||
}
|
||||
ip := prefix.Addr().As16()
|
||||
buf := buffer.NewBuffer(128)
|
||||
defer buf.Put()
|
||||
loop:
|
||||
for i := 0; i < len(ip); i++ {
|
||||
for j := 0; j < 8; j++ {
|
||||
if (ip[i]>>j)&1 == 1 {
|
||||
buf.WriteByte('1')
|
||||
} else {
|
||||
buf.WriteByte('0')
|
||||
}
|
||||
bits--
|
||||
if bits == 0 {
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.Prefix, upstream *routing.Outbound) (err error) {
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule := responseMatchSet{
|
||||
Value: uint16(len(b.ipSet)),
|
||||
Type: consts.MatchType_IpSet,
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
}
|
||||
var keys []string
|
||||
// Convert netip.Prefix -> '0' '1' string
|
||||
for _, prefix := range cidrs {
|
||||
keys = append(keys, prefix2bin128(prefix))
|
||||
}
|
||||
t, err := trie.NewTrie(keys, ValidCidrChars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.ipSet = append(b.ipSet, t)
|
||||
b.rules = append(b.rules, rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
|
||||
switch consts.RoutingDomainKey(key) {
|
||||
case consts.RoutingDomainKey_Regex,
|
||||
consts.RoutingDomainKey_Full,
|
||||
consts.RoutingDomainKey_Keyword,
|
||||
consts.RoutingDomainKey_Suffix:
|
||||
default:
|
||||
return fmt.Errorf("addQName: unsupported key: %v", key)
|
||||
}
|
||||
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
||||
Key: consts.RoutingDomainKey(key),
|
||||
RuleIndex: len(b.simulatedDomainSet),
|
||||
Domains: values,
|
||||
})
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_DomainSet,
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values []string, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
upstreamName = upstream.Name
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstreamName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lastUpstreamId, err := b.upstreamToId(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_Upstream,
|
||||
Value: uint16(lastUpstreamId),
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
upstreamName = upstream.Name
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstreamName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_QType,
|
||||
Value: uint16(value),
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
|
||||
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_Fallback,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) Build() (matcher *ResponseMatcher, err error) {
|
||||
var m ResponseMatcher
|
||||
// Build domainMatcher.
|
||||
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
|
||||
for _, domains := range b.simulatedDomainSet {
|
||||
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
|
||||
}
|
||||
if err = m.domainMatcher.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// IpSet.
|
||||
m.ipSet = b.ipSet
|
||||
|
||||
// Write routings.
|
||||
// Fallback rule MUST be the last.
|
||||
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
|
||||
return nil, fmt.Errorf("fallback rule MUST be the last")
|
||||
}
|
||||
m.matches = b.rules
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type ResponseMatcher struct {
|
||||
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
|
||||
ipSet []*trie.Trie
|
||||
|
||||
matches []responseMatchSet
|
||||
}
|
||||
|
||||
type responseMatchSet struct {
|
||||
Value uint16
|
||||
Not bool
|
||||
Type consts.MatchType
|
||||
Upstream uint8
|
||||
}
|
||||
|
||||
func (m *ResponseMatcher) Match(
|
||||
qName string,
|
||||
qType dnsmessage.Type,
|
||||
ips []netip.Addr,
|
||||
upstream consts.DnsRequestOutboundIndex,
|
||||
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
|
||||
if qName == "" {
|
||||
return 0, fmt.Errorf("qName cannot be empty")
|
||||
}
|
||||
qName = strings.TrimSuffix(strings.ToLower(qName), ".")
|
||||
domainMatchBitmap := m.domainMatcher.MatchDomainBitmap(qName)
|
||||
bin128 := make([]string, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
bin128 = append(bin128, prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128)))
|
||||
}
|
||||
|
||||
goodSubrule := false
|
||||
badRule := false
|
||||
for i, match := range m.matches {
|
||||
if badRule || goodSubrule {
|
||||
goto beforeNextLoop
|
||||
}
|
||||
switch match.Type {
|
||||
case consts.MatchType_DomainSet:
|
||||
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_IpSet:
|
||||
for _, bin128 := range bin128 {
|
||||
// Check if any of IP hit the rule.
|
||||
if m.ipSet[match.Value].HasPrefix(bin128) {
|
||||
goodSubrule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
case consts.MatchType_QType:
|
||||
if qType == dnsmessage.Type(match.Value) {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Upstream:
|
||||
if upstream == consts.DnsRequestOutboundIndex(match.Value) {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Fallback:
|
||||
goodSubrule = true
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown match type: %v", match.Type)
|
||||
}
|
||||
beforeNextLoop:
|
||||
upstream := consts.DnsResponseOutboundIndex(match.Upstream)
|
||||
if upstream != consts.DnsResponseOutboundIndex_LogicalOr {
|
||||
// This match_set reaches the end of subrule.
|
||||
// We are now at end of rule, or next match_set belongs to another
|
||||
// subrule.
|
||||
|
||||
if goodSubrule == match.Not {
|
||||
// This subrule does not hit.
|
||||
badRule = true
|
||||
}
|
||||
|
||||
// Reset goodSubrule.
|
||||
goodSubrule = false
|
||||
}
|
||||
|
||||
if upstream&consts.DnsResponseOutboundIndex_LogicalMask !=
|
||||
consts.DnsResponseOutboundIndex_LogicalMask {
|
||||
// Tail of a rule (line).
|
||||
// Decide whether to hit.
|
||||
if !badRule {
|
||||
return upstream, nil
|
||||
}
|
||||
badRule = false
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("no match set hit")
|
||||
}
|
153
component/dns/upstream.go
Normal file
153
component/dns/upstream.go
Normal file
@ -0,0 +1,153 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/mzz2017/softwind/protocol/direct"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/common/netutils"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UpstreamScheme string
|
||||
|
||||
const (
|
||||
UpstreamScheme_TCP UpstreamScheme = "tcp"
|
||||
UpstreamScheme_UDP UpstreamScheme = "udp"
|
||||
UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp"
|
||||
)
|
||||
|
||||
func (s UpstreamScheme) ContainsTcp() bool {
|
||||
switch s {
|
||||
case UpstreamScheme_TCP,
|
||||
UpstreamScheme_TCP_UDP:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, err error) {
|
||||
var __port string
|
||||
switch scheme = UpstreamScheme(raw.Scheme); scheme {
|
||||
case UpstreamScheme_TCP, UpstreamScheme_UDP, UpstreamScheme_TCP_UDP:
|
||||
__port = raw.Port()
|
||||
if __port == "" {
|
||||
__port = "53"
|
||||
}
|
||||
default:
|
||||
return "", "", 0, fmt.Errorf("unexpected dns_upstream format")
|
||||
}
|
||||
_port, err := strconv.ParseUint(__port, 10, 16)
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("failed to parse dns_upstream port: %v", err)
|
||||
}
|
||||
port = uint16(_port)
|
||||
hostname = raw.Hostname()
|
||||
return scheme, hostname, port, nil
|
||||
}
|
||||
|
||||
type Upstream struct {
|
||||
Scheme UpstreamScheme
|
||||
Hostname string
|
||||
Port uint16
|
||||
*netutils.Ip46
|
||||
}
|
||||
|
||||
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
|
||||
scheme, hostname, port, err := ParseRawUpstream(upstream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
systemDns, err := netutils.SystemDns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = netutils.TryUpdateSystemDns1s()
|
||||
}
|
||||
}()
|
||||
|
||||
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
|
||||
}
|
||||
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
|
||||
return nil, fmt.Errorf("dns_upstream has no record")
|
||||
}
|
||||
|
||||
return &Upstream{
|
||||
Scheme: scheme,
|
||||
Hostname: hostname,
|
||||
Port: port,
|
||||
Ip46: ip46,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
|
||||
if u.Ip4.IsValid() && u.Ip6.IsValid() {
|
||||
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6}
|
||||
} else {
|
||||
if u.Ip4.IsValid() {
|
||||
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4}
|
||||
} else {
|
||||
ipversions = []consts.IpVersionStr{consts.IpVersionStr_6}
|
||||
}
|
||||
}
|
||||
switch u.Scheme {
|
||||
case UpstreamScheme_TCP:
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP}
|
||||
case UpstreamScheme_UDP:
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP}
|
||||
case UpstreamScheme_TCP_UDP:
|
||||
// UDP first.
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP, consts.L4ProtoStr_TCP}
|
||||
}
|
||||
return ipversions, l4protos
|
||||
}
|
||||
|
||||
func (u *Upstream) String() string {
|
||||
return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port)))
|
||||
}
|
||||
|
||||
type UpstreamResolver struct {
|
||||
Raw *url.URL
|
||||
// FinishInitCallback may be invoked again if err is not nil
|
||||
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
|
||||
mu sync.Mutex
|
||||
upstream *Upstream
|
||||
init bool
|
||||
}
|
||||
|
||||
func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if !u.init {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
if err = u.FinishInitCallback(u.Raw, u.upstream); err != nil {
|
||||
u.upstream = nil
|
||||
return
|
||||
}
|
||||
u.init = true
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
|
||||
defer cancel()
|
||||
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
|
||||
}
|
||||
}
|
||||
return u.upstream, nil
|
||||
}
|
@ -90,11 +90,8 @@ func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) {
|
||||
}
|
||||
|
||||
func (a *AliveDialerSet) printLatencies() {
|
||||
if !a.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
return
|
||||
}
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("%v (%v):\n", a.dialerGroupName, a.CheckTyp.String()))
|
||||
builder.WriteString(fmt.Sprintf("Group '%v' [%v]:\n", a.dialerGroupName, a.CheckTyp.String()))
|
||||
for _, d := range a.inorderedAliveDialerSet {
|
||||
latency, ok := a.dialerToLatency[d]
|
||||
if !ok {
|
||||
@ -210,9 +207,13 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
|
||||
string(a.selectionPolicy): a.minLatency.latency,
|
||||
"group": a.dialerGroupName,
|
||||
"network": a.CheckTyp.String(),
|
||||
"new dialer": a.minLatency.dialer.Name(),
|
||||
"old dialer": oldDialerName,
|
||||
"new_dialer": a.minLatency.dialer.Name(),
|
||||
"old_dialer": oldDialerName,
|
||||
}).Infof("Group %vselects dialer", re)
|
||||
|
||||
if a.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
a.printLatencies()
|
||||
}
|
||||
} else {
|
||||
// Alive -> not alive
|
||||
defer a.aliveChangeCallback(false)
|
||||
@ -221,9 +222,6 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
|
||||
"network": a.CheckTyp.String(),
|
||||
}).Infof("Group has no dialer alive")
|
||||
}
|
||||
if a.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
a.printLatencies()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if alive && minPolicy && a.minLatency.dialer == nil {
|
||||
|
@ -118,7 +118,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
|
||||
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -153,7 +153,7 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort string) (opt *CheckDns
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad port: %v", err)
|
||||
}
|
||||
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
|
||||
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -409,6 +409,10 @@ func (d *Dialer) NotifyCheck() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) MustGetLatencies10(typ *NetworkType) *LatenciesN {
|
||||
return d.mustGetCollection(typ).Latencies10
|
||||
}
|
||||
|
||||
// RegisterAliveDialerSet is thread-safe.
|
||||
func (d *Dialer) RegisterAliveDialerSet(a *AliveDialerSet) {
|
||||
if a == nil {
|
||||
|
@ -19,7 +19,7 @@ type DialerSelectionPolicy struct {
|
||||
FixedIndex int
|
||||
}
|
||||
|
||||
func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) {
|
||||
func NewDialerSelectionPolicyFromGroupParam(param *config.Group) (policy *DialerSelectionPolicy, err error) {
|
||||
switch val := param.Policy.(type) {
|
||||
case string:
|
||||
switch consts.DialerSelectionPolicy(val) {
|
||||
|
@ -15,6 +15,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ValidDomainChars = trie.NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^"))
|
||||
|
||||
type AhocorasickSlimtrie struct {
|
||||
validAcIndexes []int
|
||||
validTrieIndexes []int
|
||||
@ -173,7 +175,7 @@ func (n *AhocorasickSlimtrie) Build() (err error) {
|
||||
}
|
||||
toBuild = ToSuffixTrieStrings(toBuild)
|
||||
sort.Strings(toBuild)
|
||||
n.trie[i], err = trie.NewTrie(toBuild)
|
||||
n.trie[i], err = trie.NewTrie(toBuild, ValidDomainChars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -94,7 +94,6 @@ var TestSample = []string{
|
||||
}
|
||||
|
||||
type RoutingMatcherBuilder struct {
|
||||
*routing.DefaultMatcherBuilder
|
||||
outboundName2Id map[string]uint8
|
||||
simulatedDomainSet []routing.DomainSet
|
||||
Fallback string
|
||||
@ -118,7 +117,7 @@ func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string,
|
||||
consts.RoutingDomainKey_Keyword,
|
||||
consts.RoutingDomainKey_Suffix:
|
||||
default:
|
||||
b.err = fmt.Errorf("AddDomain: unsupported key: %v", key)
|
||||
b.err = fmt.Errorf("addDomain: unsupported key: %v", key)
|
||||
return
|
||||
}
|
||||
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
||||
@ -132,22 +131,16 @@ func getDomain() (simulatedDomainSet []routing.DomainSet, err error) {
|
||||
var rules []*config_parser.RoutingRule
|
||||
sections, err := config_parser.Parse(`
|
||||
routing {
|
||||
pname(NetworkManager, dnsmasq, systemd-resolved) -> must_direct # Traffic of DNS in local must be direct to avoid loop when binding to WAN.
|
||||
pname(sogou-qimpanel, sogou-qimpanel-watchdog) -> block
|
||||
ip(geoip:private, 224.0.0.0/3, 'ff00::/8') -> direct # Put it in front unless you know what you're doing.
|
||||
domain(geosite:bing)->us
|
||||
domain(full:dns.google) && port(53) -> direct
|
||||
domain(full:dns.google) -> direct
|
||||
domain(geosite:category-ads-all) -> block
|
||||
ip(geoip:private) -> direct
|
||||
ip(geoip:cn) -> direct
|
||||
domain(geosite:cn) -> direct
|
||||
fallback: my_group
|
||||
}`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var r config.Routing
|
||||
if err = config.RoutingRuleAndParamParser(reflect.ValueOf(&r), sections[0]); err != nil {
|
||||
if err = config.SectionParser(reflect.ValueOf(&r), sections[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rules, err = routing.ApplyRulesOptimizers(r.Rules,
|
||||
@ -159,8 +152,13 @@ routing {
|
||||
return nil, fmt.Errorf("ApplyRulesOptimizers error:\n%w", err)
|
||||
}
|
||||
builder := RoutingMatcherBuilder{}
|
||||
if err = routing.ApplyMatcherBuilder(logrus.StandardLogger(), &builder, rules, r.Fallback); err != nil {
|
||||
return nil, fmt.Errorf("ApplyMatcherBuilder: %w", err)
|
||||
rb := routing.NewRulesBuilder(logrus.StandardLogger())
|
||||
rb.RegisterFunctionParser("domain", func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
|
||||
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
|
||||
return nil
|
||||
})
|
||||
if err = rb.Apply(rules); err != nil {
|
||||
return nil, fmt.Errorf("Apply: %w", err)
|
||||
}
|
||||
return builder.simulatedDomainSet, nil
|
||||
}
|
||||
|
139
component/routing/function_parser.go
Normal file
139
component/routing/function_parser.go
Normal file
@ -0,0 +1,139 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type FunctionParser func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)
|
||||
|
||||
// Preset function parser factories.
|
||||
|
||||
// PlainParserFactory is for style unity.
|
||||
func PlainParserFactory(callback func(f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
return callback(f, key, paramValueGroup, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
// EmptyKeyPlainParserFactory only accepts function with empty key.
|
||||
func EmptyKeyPlainParserFactory(callback func(f *config_parser.Function, values []string, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
if key != "" {
|
||||
return fmt.Errorf("this function cannot accept a key")
|
||||
}
|
||||
return callback(f, paramValueGroup, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func IpParserFactory(callback func(f *config_parser.Function, cidrs []netip.Prefix, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
cidrs, err := parsePrefixes(paramValueGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return callback(f, cidrs, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func MacParserFactory(callback func(f *config_parser.Function, macAddrs [][6]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var macAddrs [][6]byte
|
||||
for _, v := range paramValueGroup {
|
||||
mac, err := common.ParseMac(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
macAddrs = append(macAddrs, mac)
|
||||
}
|
||||
return callback(f, macAddrs, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func PortRangeParserFactory(callback func(f *config_parser.Function, portRanges [][2]uint16, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var portRanges [][2]uint16
|
||||
for _, v := range paramValueGroup {
|
||||
portRange, err := common.ParsePortRange(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
portRanges = append(portRanges, portRange)
|
||||
}
|
||||
return callback(f, portRanges, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func L4ProtoParserFactory(callback func(f *config_parser.Function, l4protoType consts.L4ProtoType, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var l4protoType consts.L4ProtoType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "tcp":
|
||||
l4protoType |= consts.L4ProtoType_TCP
|
||||
case "udp":
|
||||
l4protoType |= consts.L4ProtoType_UDP
|
||||
}
|
||||
}
|
||||
return callback(f, l4protoType, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func IpVersionParserFactory(callback func(f *config_parser.Function, ipVersion consts.IpVersionType, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var ipVersion consts.IpVersionType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "4":
|
||||
ipVersion |= consts.IpVersion_4
|
||||
case "6":
|
||||
ipVersion |= consts.IpVersion_6
|
||||
}
|
||||
}
|
||||
return callback(f, ipVersion, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func ProcessNameParserFactory(callback func(f *config_parser.Function, procNames [][consts.TaskCommLen]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var procNames [][consts.TaskCommLen]byte
|
||||
for _, v := range paramValueGroup {
|
||||
if len([]byte(v)) > consts.TaskCommLen {
|
||||
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
|
||||
}
|
||||
procNames = append(procNames, toProcessName(v))
|
||||
}
|
||||
return callback(f, procNames, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func parsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
|
||||
for _, value := range values {
|
||||
toParse := value
|
||||
if strings.LastIndexByte(value, '/') == -1 {
|
||||
toParse += "/32"
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(toParse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
|
||||
}
|
||||
cidrs = append(cidrs, prefix)
|
||||
}
|
||||
return cidrs, nil
|
||||
}
|
||||
|
||||
func toProcessName(processName string) (procName [consts.TaskCommLen]byte) {
|
||||
n := []byte(processName)
|
||||
copy(procName[:], n)
|
||||
return procName
|
||||
}
|
@ -8,18 +8,11 @@ package routing
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var FakeOutbound_MUST_DIRECT = consts.OutboundMustDirect.String()
|
||||
var FakeOutbound_AND = consts.OutboundLogicalAnd.String()
|
||||
var FakeOutbound_OR = consts.OutboundLogicalOr.String()
|
||||
|
||||
type DomainSet struct {
|
||||
Key consts.RoutingDomainKey
|
||||
RuleIndex int
|
||||
@ -31,22 +24,71 @@ type Outbound struct {
|
||||
Mark uint32
|
||||
}
|
||||
|
||||
type MatcherBuilder interface {
|
||||
AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound)
|
||||
AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
|
||||
AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
|
||||
AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
|
||||
AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
|
||||
AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound)
|
||||
AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound)
|
||||
AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound)
|
||||
AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound)
|
||||
AddFallback(outbound *Outbound)
|
||||
AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound)
|
||||
AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound)
|
||||
type RulesBuilder struct {
|
||||
log *logrus.Logger
|
||||
parsers map[string]FunctionParser
|
||||
}
|
||||
|
||||
func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
|
||||
func NewRulesBuilder(log *logrus.Logger) *RulesBuilder {
|
||||
return &RulesBuilder{
|
||||
log: log,
|
||||
parsers: make(map[string]FunctionParser),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *RulesBuilder) RegisterFunctionParser(funcName string, parser FunctionParser) {
|
||||
b.parsers[funcName] = parser
|
||||
}
|
||||
|
||||
func (b *RulesBuilder) Apply(rules []*config_parser.RoutingRule) (err error) {
|
||||
for _, rule := range rules {
|
||||
b.log.Debugln("[rule]", rule.String(true))
|
||||
outbound, err := ParseOutbound(&rule.Outbound)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
|
||||
for iFunc, f := range rule.AndFunctions {
|
||||
// f is like: domain(domain:baidu.com)
|
||||
functionParser, ok := b.parsers[f.Name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown function: %v", f.Name)
|
||||
}
|
||||
paramValueGroups, keyOrder := groupParamValuesByKey(f.Params)
|
||||
for jMatchSet, key := range keyOrder {
|
||||
paramValueGroup := paramValueGroups[key]
|
||||
// Preprocess the outbound.
|
||||
overrideOutbound := &Outbound{
|
||||
Name: consts.OutboundLogicalOr.String(),
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
if jMatchSet == len(keyOrder)-1 {
|
||||
overrideOutbound.Name = consts.OutboundLogicalAnd.String()
|
||||
if iFunc == len(rule.AndFunctions)-1 {
|
||||
overrideOutbound.Name = outbound.Name
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Debug
|
||||
symNot := ""
|
||||
if f.Not {
|
||||
symNot = "!"
|
||||
}
|
||||
b.log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound.Name)
|
||||
}
|
||||
|
||||
if err = functionParser(b.log, f, key, paramValueGroup, overrideOutbound); err != nil {
|
||||
return fmt.Errorf("failed to parse '%v': %w", f.String(false), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func groupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
|
||||
groups := make(map[string][]string)
|
||||
for _, param := range params {
|
||||
if _, ok := groups[param.Key]; !ok {
|
||||
@ -57,28 +99,7 @@ func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[strin
|
||||
return groups, keyOrder
|
||||
}
|
||||
|
||||
func ParsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
|
||||
for _, value := range values {
|
||||
toParse := value
|
||||
if strings.LastIndexByte(value, '/') == -1 {
|
||||
toParse += "/32"
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(toParse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
|
||||
}
|
||||
cidrs = append(cidrs, prefix)
|
||||
}
|
||||
return cidrs, nil
|
||||
}
|
||||
|
||||
func ToProcessName(processName string) (procName [consts.TaskCommLen]byte) {
|
||||
n := []byte(processName)
|
||||
copy(procName[:], n)
|
||||
return procName
|
||||
}
|
||||
|
||||
func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
|
||||
func ParseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
|
||||
outbound = &Outbound{
|
||||
Name: rawOutbound.Name,
|
||||
Mark: 0,
|
||||
@ -98,164 +119,3 @@ func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*config_parser.RoutingRule, fallbackOutbound interface{}) (err error) {
|
||||
for _, rule := range rules {
|
||||
log.Debugln("[rule]", rule.String(true))
|
||||
outbound, err := parseOutbound(&rule.Outbound)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
|
||||
for iFunc, f := range rule.AndFunctions {
|
||||
// f is like: domain(domain:baidu.com)
|
||||
paramValueGroups, keyOrder := GroupParamValuesByKey(f.Params)
|
||||
for jMatchSet, key := range keyOrder {
|
||||
paramValueGroup := paramValueGroups[key]
|
||||
// Preprocess the outbound.
|
||||
overrideOutbound := &Outbound{
|
||||
Name: FakeOutbound_OR,
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
if jMatchSet == len(keyOrder)-1 {
|
||||
overrideOutbound.Name = FakeOutbound_AND
|
||||
if iFunc == len(rule.AndFunctions)-1 {
|
||||
overrideOutbound.Name = outbound.Name
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Debug
|
||||
symNot := ""
|
||||
if f.Not {
|
||||
symNot = "!"
|
||||
}
|
||||
log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound)
|
||||
}
|
||||
|
||||
builder.AddAnyBefore(f, key, paramValueGroup, overrideOutbound)
|
||||
switch f.Name {
|
||||
case consts.Function_Domain:
|
||||
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
|
||||
case consts.Function_Ip, consts.Function_SourceIp:
|
||||
cidrs, err := ParsePrefixes(paramValueGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f.Name == consts.Function_Ip {
|
||||
builder.AddIp(f, cidrs, overrideOutbound)
|
||||
} else {
|
||||
builder.AddSourceIp(f, cidrs, overrideOutbound)
|
||||
}
|
||||
case consts.Function_Mac:
|
||||
var macAddrs [][6]byte
|
||||
for _, v := range paramValueGroup {
|
||||
mac, err := common.ParseMac(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
macAddrs = append(macAddrs, mac)
|
||||
}
|
||||
builder.AddSourceMac(f, macAddrs, overrideOutbound)
|
||||
case consts.Function_Port, consts.Function_SourcePort:
|
||||
var portRanges [][2]uint16
|
||||
for _, v := range paramValueGroup {
|
||||
portRange, err := common.ParsePortRange(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
portRanges = append(portRanges, portRange)
|
||||
}
|
||||
if f.Name == consts.Function_Port {
|
||||
builder.AddPort(f, portRanges, overrideOutbound)
|
||||
} else {
|
||||
builder.AddSourcePort(f, portRanges, overrideOutbound)
|
||||
}
|
||||
case consts.Function_L4Proto:
|
||||
var l4protoType consts.L4ProtoType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "tcp":
|
||||
l4protoType |= consts.L4ProtoType_TCP
|
||||
case "udp":
|
||||
l4protoType |= consts.L4ProtoType_UDP
|
||||
}
|
||||
}
|
||||
builder.AddL4Proto(f, l4protoType, overrideOutbound)
|
||||
case consts.Function_IpVersion:
|
||||
var ipVersion consts.IpVersionType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "4":
|
||||
ipVersion |= consts.IpVersion_4
|
||||
case "6":
|
||||
ipVersion |= consts.IpVersion_6
|
||||
}
|
||||
}
|
||||
builder.AddIpVersion(f, ipVersion, overrideOutbound)
|
||||
case consts.Function_ProcessName:
|
||||
var procNames [][consts.TaskCommLen]byte
|
||||
for _, v := range paramValueGroup {
|
||||
if len([]byte(v)) > consts.TaskCommLen {
|
||||
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
|
||||
}
|
||||
procNames = append(procNames, ToProcessName(v))
|
||||
}
|
||||
builder.AddProcessName(f, procNames, overrideOutbound)
|
||||
default:
|
||||
return fmt.Errorf("unsupported function name: %v", f.Name)
|
||||
}
|
||||
builder.AddAnyAfter(f, key, paramValueGroup, overrideOutbound)
|
||||
}
|
||||
}
|
||||
}
|
||||
var rawFallback *config_parser.Function
|
||||
switch fallback := fallbackOutbound.(type) {
|
||||
case string:
|
||||
rawFallback = &config_parser.Function{Name: fallback}
|
||||
case *config_parser.Function:
|
||||
rawFallback = fallback
|
||||
default:
|
||||
return fmt.Errorf("unknown type of 'fallback' in section routing: %T", fallback)
|
||||
}
|
||||
fallback, err := parseOutbound(rawFallback)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
builder.AddAnyBefore(&config_parser.Function{
|
||||
Name: "fallback",
|
||||
}, "", nil, fallback)
|
||||
builder.AddFallback(fallback)
|
||||
builder.AddAnyAfter(&config_parser.Function{
|
||||
Name: "fallback",
|
||||
}, "", nil, fallback)
|
||||
return nil
|
||||
}
|
||||
|
||||
type DefaultMatcherBuilder struct {
|
||||
}
|
||||
|
||||
func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddFallback(outbound *Outbound) {}
|
||||
func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound) {
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ func (r *CryptoFrameRelocation) BytesFromPool() []byte {
|
||||
return pool.Get(0)
|
||||
}
|
||||
right := r.o[len(r.o)-1]
|
||||
return r.copyBytes(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
|
||||
return r.copyBytesToPool(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
|
||||
}
|
||||
|
||||
// RangeFromPool copy bytes from iUpperAppOffset to jUpperAppOffset.
|
||||
@ -191,11 +191,11 @@ func (r *CryptoFrameRelocation) RangeFromPool(i, j int) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
return r.copyBytes(iOuter, iInner, jOuter, jInner, j-i+1)
|
||||
return r.copyBytesToPool(iOuter, iInner, jOuter, jInner, j-i+1)
|
||||
}
|
||||
|
||||
// copyBytes copy bytes including i and j.
|
||||
func (r *CryptoFrameRelocation) copyBytes(iOuter, iInner, jOuter, jInner, size int) []byte {
|
||||
// copyBytesToPool copy bytes including i and j.
|
||||
func (r *CryptoFrameRelocation) copyBytesToPool(iOuter, iInner, jOuter, jInner, size int) []byte {
|
||||
b := pool.Get(size)
|
||||
//io := r.o[iOuter]
|
||||
k := 0
|
||||
|
@ -8,6 +8,7 @@ package sniffing
|
||||
import (
|
||||
"github.com/mzz2017/softwind/pool"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Sniffer struct {
|
||||
@ -15,12 +16,13 @@ type Sniffer struct {
|
||||
buf []byte
|
||||
bufAt int
|
||||
stream bool
|
||||
readMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewStreamSniffer(r io.Reader, bufSize int) *Sniffer {
|
||||
s := &Sniffer{
|
||||
r: r,
|
||||
buf: pool.Get(bufSize),
|
||||
buf: make([]byte, bufSize),
|
||||
stream: true,
|
||||
}
|
||||
return s
|
||||
@ -37,6 +39,8 @@ func NewPacketSniffer(data []byte) *Sniffer {
|
||||
type sniff func() (d string, err error)
|
||||
|
||||
func (s *Sniffer) SniffTcp() (d string, err error) {
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
if s.stream {
|
||||
n, err := s.r.Read(s.buf)
|
||||
if err != nil {
|
||||
@ -65,6 +69,8 @@ func (s *Sniffer) SniffTcp() (d string, err error) {
|
||||
}
|
||||
|
||||
func (s *Sniffer) Read(p []byte) (n int, err error) {
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
if s.buf != nil && s.bufAt < len(s.buf) {
|
||||
// Read buf first.
|
||||
n = copy(p, s.buf[s.bufAt:])
|
||||
@ -84,6 +90,5 @@ func (s *Sniffer) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (s *Sniffer) Close() (err error) {
|
||||
// DO NOT use pool.Put() here because Close() may not interrupt the reading, which will modify the value of the pool buffer.
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user