feat: dns routing (#26)

This commit is contained in:
mzz
2023-02-25 02:38:21 +08:00
committed by GitHub
parent 33ad434f8a
commit 8bd6a77398
48 changed files with 2758 additions and 1449 deletions

View File

@ -4,18 +4,19 @@
**_dae_**, means goose, is a lightweight and high-performance transparent proxy solution.
In order to improve the traffic split performance as much as possible, dae runs the transparent proxy and traffic split suite in the linux kernel by eBPF. Therefore, we have the opportunity to make the direct traffic bypass the forwarding by proxy application and achieve true direct traffic through. Under such a magic trick, there is almost no performance loss and additional resource consumption for direct traffic.
In order to improve the traffic split performance as much as possible, dae runs the transparent proxy and traffic split suite in the linux kernel by eBPF. Therefore, dae has the opportunity to make the direct traffic bypass the forwarding by proxy application and achieve true direct traffic through. Under such a magic trick, there is almost no performance loss and additional resource consumption for direct traffic.
As a successor of [v2rayA](https://github.com/v2rayA/v2rayA), dae abandoned v2ray-core to meet the needs of users more freely.
**Features**
1. Implement `Real direct` traffic split (need ipforward on) to achieve [high performance](https://docs.google.com/spreadsheets/d/1UaWU6nNho7edBNjNqC8dfGXLlW0-cm84MM7sH6Gp7UE/edit?usp=sharing).
1. Implement `Real Direct` traffic split (need ipforward on) to achieve [high performance](https://docs.google.com/spreadsheets/d/1UaWU6nNho7edBNjNqC8dfGXLlW0-cm84MM7sH6Gp7UE/edit?usp=sharing).
1. Support to split traffic by process name in local host.
1. Support to split traffic by MAC address in LAN.
1. Support to split traffic with invert match rules.
1. Support to automatically switch nodes according to policy. That is to say, support to automatically test independent TCP/UDP/IPv4/IPv6 latencies, and then use the best nodes for corresponding traffic according to user-defined policy.
1. Support full-cone NAT for shadowsocks, vmess, socks5 and trojan(-go).
1. Support advanced DNS resolution process.
1. Support full-cone NAT for shadowsocks, trojan(-go) and socks5 (no test).
## Prerequisites
@ -81,13 +82,9 @@ Please refer to [Quick Start Guide](./docs/getting-started/README.md) to start u
## TODO
- [ ] Check dns upstream and source loop (whether upstream is also a client of us) and remind the user to add sip rule.
- [ ] WAN L4Checksum problem.
- [ ] If the NIC checksumming offload is enabled, the Linux network stack will make a simple checksum a packet when it is sent out from local. When NIC discovers that the source IP of the packet is the local IP of the NIC, it will checksum it complete this checksum.
- [ ] But the problem is, after the Linux network stack, before entering the network card, we modify the source IP of this packet, causing the Linux network stack to only make a simple checksum, and the NIC also assumes that this packet is not sent from local, so no further checksum completing.
- [ ] Automatically check dns upstream and source loop (whether upstream is also a client of us) and remind the user to add sip rule.
- [ ] MACv2 extension extraction.
- [ ] Log to userspace.
- [ ] Protocol-oriented node features detecting (or filter), such as full-cone (especially VMess and VLESS).
- [ ] DNS traffic split.
- [ ] Add quick-start guide
- [ ] ...

View File

@ -11,6 +11,7 @@ import (
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"
)
@ -84,11 +85,15 @@ func Run(log *logrus.Logger, param *config.Params) (err error) {
param.Group,
&param.Routing,
&param.Global,
&param.Dns,
)
if err != nil {
return err
}
// Call GC to release memory.
runtime.GC()
// Serve tproxy TCP/UDP server util signals.
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGILL)

View File

@ -24,11 +24,11 @@ func NewCompactBitList(unitBitSize int) *CompactBitList {
return &CompactBitList{
unitBitSize: unitBitSize,
size: 0,
b: anybuffer.NewBuffer[uint16](1),
b: anybuffer.NewBuffer[uint16](8),
}
}
// Set is not optimized yet.
// Set function is not optimized yet.
func (m *CompactBitList) Set(iUnit int, v uint64) {
if bits.Len64(v) > m.unitBitSize {
panic(fmt.Sprintf("value %v exceeds unit bit size", v))

View File

@ -41,6 +41,16 @@ func (l L4ProtoStr) ToL4Proto() uint8 {
panic("unsupported l4proto")
}
func (l L4ProtoStr) ToL4ProtoType() L4ProtoType {
switch l {
case L4ProtoStr_TCP:
return L4ProtoType_TCP
case L4ProtoStr_UDP:
return L4ProtoType_UDP
}
panic("unsupported l4proto: " + l)
}
type IpVersionStr string
const (
@ -58,6 +68,16 @@ func (v IpVersionStr) ToIpVersion() uint8 {
panic("unsupported ipversion")
}
func (v IpVersionStr) ToIpVersionType() IpVersionType {
switch v {
case IpVersionStr_4:
return IpVersion_4
case IpVersionStr_6:
return IpVersion_6
}
panic("unsupported ipversion")
}
func IpVersionFromAddr(addr netip.Addr) IpVersionStr {
var ipversion IpVersionStr
switch {

66
common/consts/dns.go Normal file
View File

@ -0,0 +1,66 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package consts
import (
"strconv"
"strings"
)
type DnsRequestOutboundIndex uint8
const (
DnsRequestOutboundIndex_AsIs DnsRequestOutboundIndex = 0xFD
DnsRequestOutboundIndex_LogicalOr DnsRequestOutboundIndex = 0xFE
DnsRequestOutboundIndex_LogicalAnd DnsRequestOutboundIndex = 0xFF
DnsRequestOutboundIndex_LogicalMask DnsRequestOutboundIndex = 0xFE
DnsRequestOutboundIndex_UserDefinedMax = DnsRequestOutboundIndex_AsIs - 1
)
func (i DnsRequestOutboundIndex) String() string {
switch i {
case DnsRequestOutboundIndex_AsIs:
return "asis"
case DnsRequestOutboundIndex_LogicalOr:
return "<OR>"
case DnsRequestOutboundIndex_LogicalAnd:
return "<AND>"
default:
return "<index: " + strconv.Itoa(int(i)) + ">"
}
}
type DnsResponseOutboundIndex uint8
const (
DnsResponseOutboundIndex_Accept DnsResponseOutboundIndex = 0xFC
DnsResponseOutboundIndex_Reject DnsResponseOutboundIndex = 0xFD
DnsResponseOutboundIndex_LogicalOr DnsResponseOutboundIndex = 0xFE
DnsResponseOutboundIndex_LogicalAnd DnsResponseOutboundIndex = 0xFF
DnsResponseOutboundIndex_LogicalMask DnsResponseOutboundIndex = 0xFE
DnsResponseOutboundIndex_UserDefinedMax = DnsResponseOutboundIndex_Accept - 1
)
func (i DnsResponseOutboundIndex) String() string {
switch i {
case DnsResponseOutboundIndex_Accept:
return "accept"
case DnsResponseOutboundIndex_Reject:
return "reject"
case DnsResponseOutboundIndex_LogicalOr:
return "<OR>"
case DnsResponseOutboundIndex_LogicalAnd:
return "<AND>"
default:
return "<index: " + strconv.Itoa(int(i)) + ">"
}
}
func (i DnsResponseOutboundIndex) IsReserved() bool {
return !strings.HasPrefix(i.String(), "<index: ")
}

View File

@ -27,6 +27,7 @@ const (
DisableL4RxChecksumKey
ControlPlanePidKey
ControlPlaneNatDirectKey
ControlPlaneDnsRoutingKey
OneKey ParamKey = 1
)
@ -52,15 +53,19 @@ const (
MatchType_Mac
MatchType_ProcessName
MatchType_Fallback
MatchType_Upstream
MatchType_QType
)
type OutboundIndex uint8
const (
OutboundDirect OutboundIndex = 0
OutboundBlock OutboundIndex = 1
OutboundDirect OutboundIndex = iota
OutboundBlock
OutboundMustDirect OutboundIndex = 0xFC
OutboundControlPlaneDirect OutboundIndex = 0xFD
OutboundControlPlaneRouting OutboundIndex = 0xFD
OutboundLogicalOr OutboundIndex = 0xFE
OutboundLogicalAnd OutboundIndex = 0xFF
OutboundLogicalMask OutboundIndex = 0xFE
@ -77,8 +82,8 @@ func (i OutboundIndex) String() string {
return "block"
case OutboundMustDirect:
return "must_direct"
case OutboundControlPlaneDirect:
return "<Control Plane Direct>"
case OutboundControlPlaneRouting:
return "<Control Plane Routing>"
case OutboundLogicalOr:
return "<OR>"
case OutboundLogicalAnd:

View File

@ -23,7 +23,9 @@ const (
Function_Mac = "mac"
Function_ProcessName = "pname"
Declaration_Fallback = "fallback"
Function_QName = "qname"
Function_QType = "qtype"
Function_Upstream = "upstream"
OutboundParam_Mark = "mark"
)

View File

@ -17,7 +17,7 @@ type Ip46 struct {
Ip6 netip.Addr
}
func ParseIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, tcp bool) (ipv46 *Ip46, err error) {
func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, tcp bool) (ipv46 *Ip46, err error) {
addrs4, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeA, tcp)
if err != nil {
return nil, err

View File

@ -12,6 +12,7 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"net/url"
@ -48,7 +49,7 @@ func ARangeU32(n uint32) []uint32 {
func Ipv6ByteSliceToUint32Array(_ip []byte) (ip [4]uint32) {
for j := 0; j < 16; j += 4 {
ip[j/4] = binary.LittleEndian.Uint32(_ip[j : j+4])
ip[j/4] = internal.NativeEndian.Uint32(_ip[j : j+4])
}
return ip
}
@ -61,7 +62,7 @@ func Ipv6ByteSliceToUint8Array(_ip []byte) (ip [16]uint8) {
func Ipv6Uint32ArrayToByteSlice(_ip [4]uint32) (ip []byte) {
ip = make([]byte, 16)
for j := 0; j < 4; j++ {
binary.LittleEndian.PutUint32(ip[j*4:], _ip[j])
internal.NativeEndian.PutUint32(ip[j*4:], _ip[j])
}
return ip
}
@ -372,13 +373,20 @@ func BoolToString(b bool) string {
}
}
func ConvergeIp(addr netip.Addr) netip.Addr {
func ConvergeAddr(addr netip.Addr) netip.Addr {
if addr.Is4In6() {
addr = netip.AddrFrom4(addr.As4())
}
return addr
}
func ConvergeAddrPort(addrPort netip.AddrPort) netip.AddrPort {
if addrPort.Addr().Is4In6() {
return netip.AddrPortFrom(netip.AddrFrom4(addrPort.Addr().As4()), addrPort.Port())
}
return addrPort
}
func NewGcm(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {

203
component/dns/dns.go Normal file
View File

@ -0,0 +1,203 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/config"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"net/url"
"sync"
)
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
type Dns struct {
upstream []*UpstreamResolver
upstream2IndexMu sync.Mutex
upstream2Index map[*Upstream]int
reqMatcher *RequestMatcher
respMatcher *ResponseMatcher
}
type NewOption struct {
UpstreamReadyCallback func(raw *url.URL, upstream *Upstream) (err error)
}
func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error) {
s = &Dns{
upstream2Index: map[*Upstream]int{
nil: int(consts.DnsRequestOutboundIndex_AsIs),
},
}
// Parse upstream.
upstreamName2Id := map[string]uint8{}
for i, upstreamRaw := range dns.Upstream {
if i >= int(consts.DnsRequestOutboundIndex_UserDefinedMax) ||
i >= int(consts.DnsResponseOutboundIndex_UserDefinedMax) {
return nil, fmt.Errorf("too many upstreams")
}
tag, link := common.GetTagFromLinkLikePlaintext(upstreamRaw)
if tag == "" {
return nil, fmt.Errorf("%w: '%v' has no tag", BadUpstreamFormatError, upstreamRaw)
}
u, err := url.Parse(link)
if err != nil {
return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err)
}
r := &UpstreamResolver{
Raw: u,
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
return func(raw *url.URL, upstream *Upstream) (err error) {
if opt != nil && opt.UpstreamReadyCallback != nil {
if err = opt.UpstreamReadyCallback(raw, upstream); err != nil {
return err
}
}
s.upstream2IndexMu.Lock()
s.upstream2Index[upstream] = i
s.upstream2IndexMu.Unlock()
return nil
}
}(i),
}
upstreamName2Id[tag] = uint8(len(s.upstream))
s.upstream = append(s.upstream, r)
}
// Optimize routings.
if dns.Routing.Request.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Request.Rules,
&routing.DatReaderOptimizer{Logger: log},
&routing.MergeAndSortRulesOptimizer{},
&routing.DeduplicateParamsOptimizer{},
); err != nil {
return nil, err
}
if dns.Routing.Response.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Response.Rules,
&routing.DatReaderOptimizer{Logger: log},
&routing.MergeAndSortRulesOptimizer{},
&routing.DeduplicateParamsOptimizer{},
); err != nil {
return nil, err
}
// Parse request routing.
reqMatcherBuilder, err := NewRequestMatcherBuilder(log, dns.Routing.Request.Rules, upstreamName2Id, dns.Routing.Request.Fallback)
if err != nil {
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
}
s.reqMatcher, err = reqMatcherBuilder.Build()
if err != nil {
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
}
// Parse response routing.
respMatcherBuilder, err := NewResponseMatcherBuilder(log, dns.Routing.Response.Rules, upstreamName2Id, dns.Routing.Response.Fallback)
if err != nil {
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
}
s.respMatcher, err = respMatcherBuilder.Build()
if err != nil {
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
}
if len(dns.Upstream) == 0 {
// Immediately ready.
if err = opt.UpstreamReadyCallback(nil, nil); err != nil {
return nil, err
}
}
return s, nil
}
func (s *Dns) RequestSelect(msg *dnsmessage.Message) (upstream *Upstream, err error) {
if msg.Response {
return nil, fmt.Errorf("DNS request expected but DNS response received")
}
// Prepare routing.
var qname string
var qtype dnsmessage.Type
if len(msg.Questions) == 0 {
qname = ""
qtype = 0
} else {
q := msg.Questions[0]
qname = q.Name.String()
qtype = q.Type
}
// Route.
upstreamIndex, err := s.reqMatcher.Match(qname, qtype)
if err != nil {
return nil, err
}
// nil indicates AsIs.
if upstreamIndex == consts.DnsRequestOutboundIndex_AsIs {
return nil, nil
}
if int(upstreamIndex) >= len(s.upstream) {
return nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
}
// Get corresponding upstream.
upstream, err = s.upstream[upstreamIndex].GetUpstream()
if err != nil {
return nil, err
}
return upstream, nil
}
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
if !msg.Response {
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
}
// Prepare routing.
var qname string
var qtype dnsmessage.Type
var ips []netip.Addr
if len(msg.Questions) == 0 {
qname = ""
qtype = 0
} else {
q := msg.Questions[0]
qname = q.Name.String()
qtype = q.Type
for _, ans := range msg.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
ips = append(ips, netip.AddrFrom4(body.A))
case *dnsmessage.AAAAResource:
ips = append(ips, netip.AddrFrom16(body.AAAA))
}
}
}
s.upstream2IndexMu.Lock()
from := s.upstream2Index[fromUpstream]
s.upstream2IndexMu.Unlock()
// Route.
upstreamIndex, err = s.respMatcher.Match(qname, qtype, ips, consts.DnsRequestOutboundIndex(from))
if err != nil {
return 0, nil, err
}
// Get corresponding upstream if upstream is neither 'accept' nor 'reject'.
if !upstreamIndex.IsReserved() {
if int(upstreamIndex) >= len(s.upstream) {
return 0, nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
}
upstream, err = s.upstream[upstreamIndex].GetUpstream()
if err != nil {
return 0, nil, err
}
} else {
// Assign explicitly to let coder know.
upstream = nil
}
return upstreamIndex, upstream, nil
}

View File

@ -0,0 +1,47 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/pkg/config_parser"
"golang.org/x/net/dns/dnsmessage"
"strings"
)
var typeNames = map[string]dnsmessage.Type{
"A": dnsmessage.TypeA,
"NS": dnsmessage.TypeNS,
"CNAME": dnsmessage.TypeCNAME,
"SOA": dnsmessage.TypeSOA,
"PTR": dnsmessage.TypePTR,
"MX": dnsmessage.TypeMX,
"TXT": dnsmessage.TypeTXT,
"AAAA": dnsmessage.TypeAAAA,
"SRV": dnsmessage.TypeSRV,
"OPT": dnsmessage.TypeOPT,
"WKS": dnsmessage.TypeWKS,
"HINFO": dnsmessage.TypeHINFO,
"MINFO": dnsmessage.TypeMINFO,
"AXFR": dnsmessage.TypeAXFR,
"ALL": dnsmessage.TypeALL,
}
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
var types []dnsmessage.Type
for _, v := range paramValueGroup {
t, ok := typeNames[strings.ToUpper(v)]
if !ok {
return fmt.Errorf("unknown DNS request type: %v", v)
}
types = append(types, t)
}
return callback(f, types, overrideOutbound)
}
}

View File

@ -0,0 +1,213 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/component/routing/domain_matcher"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"golang.org/x/net/dns/dnsmessage"
"strconv"
)
type RequestMatcherBuilder struct {
upstreamName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet
fallback *routing.Outbound
rules []requestMatchSet
}
func NewRequestMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *RequestMatcherBuilder, err error) {
b = &RequestMatcherBuilder{upstreamName2Id: upstreamName2Id}
rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
if err = rulesBuilder.Apply(rules); err != nil {
return nil, err
}
if err = b.addFallback(fallback); err != nil {
return nil, err
}
return b, nil
}
func (b *RequestMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsRequestOutboundIndex, err error) {
switch upstream {
case consts.DnsRequestOutboundIndex_AsIs.String():
upstreamId = consts.DnsRequestOutboundIndex_AsIs
case consts.DnsRequestOutboundIndex_LogicalAnd.String():
upstreamId = consts.DnsRequestOutboundIndex_LogicalAnd
case consts.DnsRequestOutboundIndex_LogicalOr.String():
upstreamId = consts.DnsRequestOutboundIndex_LogicalOr
default:
_upstreamId, ok := b.upstreamName2Id[upstream]
if !ok {
return 0, fmt.Errorf("upstream %v not found; please define it in section \"dns.upstream\"", strconv.Quote(upstream))
}
upstreamId = consts.DnsRequestOutboundIndex(_upstreamId)
}
return upstreamId, nil
}
func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
switch consts.RoutingDomainKey(key) {
case consts.RoutingDomainKey_Regex,
consts.RoutingDomainKey_Full,
consts.RoutingDomainKey_Keyword,
consts.RoutingDomainKey_Suffix:
default:
return fmt.Errorf("addQName: unsupported key: %v", key)
}
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
Key: consts.RoutingDomainKey(key),
RuleIndex: len(b.simulatedDomainSet),
Domains: values,
})
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, requestMatchSet{
Type: consts.MatchType_DomainSet,
Not: f.Not,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
upstreamName = upstream.Name
}
upstreamId, err := b.upstreamToId(upstreamName)
if err != nil {
return err
}
b.rules = append(b.rules, requestMatchSet{
Type: consts.MatchType_QType,
Value: uint16(value),
Not: f.Not,
Upstream: uint8(upstreamId),
})
}
return nil
}
func (b *RequestMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
if err != nil {
return err
}
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, requestMatchSet{
Type: consts.MatchType_Fallback,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *RequestMatcherBuilder) Build() (matcher *RequestMatcher, err error) {
var m RequestMatcher
// Build domainMatcher
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet {
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
}
if err = m.domainMatcher.Build(); err != nil {
return nil, err
}
// Write routings.
// Fallback rule MUST be the last.
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
return nil, fmt.Errorf("fallback rule MUST be the last")
}
m.matches = b.rules
return &m, nil
}
type RequestMatcher struct {
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
matches []requestMatchSet
}
type requestMatchSet struct {
Value uint16
Not bool
Type consts.MatchType
Upstream uint8
}
func (m *RequestMatcher) Match(
qName string,
qType dnsmessage.Type,
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
var domainMatchBitmap []uint32
if qName != "" {
domainMatchBitmap = m.domainMatcher.MatchDomainBitmap(qName)
}
goodSubrule := false
badRule := false
for i, match := range m.matches {
if badRule || goodSubrule {
goto beforeNextLoop
}
switch match.Type {
case consts.MatchType_DomainSet:
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
goodSubrule = true
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
goodSubrule = true
}
case consts.MatchType_Fallback:
goodSubrule = true
default:
return 0, fmt.Errorf("unknown match type: %v", match.Type)
}
beforeNextLoop:
upstream := consts.DnsRequestOutboundIndex(match.Upstream)
if upstream != consts.DnsRequestOutboundIndex_LogicalOr {
// This match_set reaches the end of subrule.
// We are now at end of rule, or next match_set belongs to another
// subrule.
if goodSubrule == match.Not {
// This subrule does not hit.
badRule = true
}
// Reset goodSubrule.
goodSubrule = false
}
if upstream&consts.DnsRequestOutboundIndex_LogicalMask !=
consts.DnsRequestOutboundIndex_LogicalMask {
// Tail of a rule (line).
// Decide whether to hit.
if !badRule {
return upstream, nil
}
badRule = false
}
}
return 0, fmt.Errorf("no match set hit")
}

View File

@ -0,0 +1,320 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package dns
import (
"fmt"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/component/routing/domain_matcher"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"github.com/v2rayA/dae/pkg/trie"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"strconv"
"strings"
)
var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'})
type ResponseMatcherBuilder struct {
upstreamName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet
ipSet []*trie.Trie
fallback *routing.Outbound
rules []responseMatchSet
}
func NewResponseMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *ResponseMatcherBuilder, err error) {
b = &ResponseMatcherBuilder{upstreamName2Id: upstreamName2Id}
rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
rulesBuilder.RegisterFunctionParser(consts.Function_Upstream, routing.EmptyKeyPlainParserFactory(b.addUpstream))
if err = rulesBuilder.Apply(rules); err != nil {
return nil, err
}
if err = b.addFallback(fallback); err != nil {
return nil, err
}
return b, nil
}
func (b *ResponseMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsResponseOutboundIndex, err error) {
switch upstream {
case consts.DnsResponseOutboundIndex_Accept.String():
upstreamId = consts.DnsResponseOutboundIndex_Accept
case consts.DnsResponseOutboundIndex_Reject.String():
upstreamId = consts.DnsResponseOutboundIndex_Reject
case consts.DnsResponseOutboundIndex_LogicalAnd.String():
upstreamId = consts.DnsResponseOutboundIndex_LogicalAnd
case consts.DnsResponseOutboundIndex_LogicalOr.String():
upstreamId = consts.DnsResponseOutboundIndex_LogicalOr
default:
_upstreamId, ok := b.upstreamName2Id[upstream]
if !ok {
return 0, fmt.Errorf("upstream %v not found; please define it in \"dns.upstream\"", strconv.Quote(upstream))
}
upstreamId = consts.DnsResponseOutboundIndex(_upstreamId)
}
return upstreamId, nil
}
func prefix2bin128(prefix netip.Prefix) (bin128 string) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
bits += 96
}
ip := prefix.Addr().As16()
buf := buffer.NewBuffer(128)
defer buf.Put()
loop:
for i := 0; i < len(ip); i++ {
for j := 0; j < 8; j++ {
if (ip[i]>>j)&1 == 1 {
buf.WriteByte('1')
} else {
buf.WriteByte('0')
}
bits--
if bits == 0 {
break loop
}
}
}
return buf.String()
}
func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.Prefix, upstream *routing.Outbound) (err error) {
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
rule := responseMatchSet{
Value: uint16(len(b.ipSet)),
Type: consts.MatchType_IpSet,
Not: f.Not,
Upstream: uint8(upstreamId),
}
var keys []string
// Convert netip.Prefix -> '0' '1' string
for _, prefix := range cidrs {
keys = append(keys, prefix2bin128(prefix))
}
t, err := trie.NewTrie(keys, ValidCidrChars)
if err != nil {
return err
}
b.ipSet = append(b.ipSet, t)
b.rules = append(b.rules, rule)
return nil
}
func (b *ResponseMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
switch consts.RoutingDomainKey(key) {
case consts.RoutingDomainKey_Regex,
consts.RoutingDomainKey_Full,
consts.RoutingDomainKey_Keyword,
consts.RoutingDomainKey_Suffix:
default:
return fmt.Errorf("addQName: unsupported key: %v", key)
}
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
Key: consts.RoutingDomainKey(key),
RuleIndex: len(b.simulatedDomainSet),
Domains: values,
})
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_DomainSet,
Not: f.Not,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values []string, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
upstreamName = upstream.Name
}
upstreamId, err := b.upstreamToId(upstreamName)
if err != nil {
return err
}
lastUpstreamId, err := b.upstreamToId(value)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_Upstream,
Value: uint16(lastUpstreamId),
Not: f.Not,
Upstream: uint8(upstreamId),
})
}
return nil
}
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
upstreamName = upstream.Name
}
upstreamId, err := b.upstreamToId(upstreamName)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_QType,
Value: uint16(value),
Not: f.Not,
Upstream: uint8(upstreamId),
})
}
return nil
}
func (b *ResponseMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
if err != nil {
return err
}
upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil {
return err
}
b.rules = append(b.rules, responseMatchSet{
Type: consts.MatchType_Fallback,
Upstream: uint8(upstreamId),
})
return nil
}
func (b *ResponseMatcherBuilder) Build() (matcher *ResponseMatcher, err error) {
var m ResponseMatcher
// Build domainMatcher.
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet {
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
}
if err = m.domainMatcher.Build(); err != nil {
return nil, err
}
// IpSet.
m.ipSet = b.ipSet
// Write routings.
// Fallback rule MUST be the last.
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
return nil, fmt.Errorf("fallback rule MUST be the last")
}
m.matches = b.rules
return &m, nil
}
type ResponseMatcher struct {
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
ipSet []*trie.Trie
matches []responseMatchSet
}
type responseMatchSet struct {
Value uint16
Not bool
Type consts.MatchType
Upstream uint8
}
func (m *ResponseMatcher) Match(
qName string,
qType dnsmessage.Type,
ips []netip.Addr,
upstream consts.DnsRequestOutboundIndex,
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
if qName == "" {
return 0, fmt.Errorf("qName cannot be empty")
}
qName = strings.TrimSuffix(strings.ToLower(qName), ".")
domainMatchBitmap := m.domainMatcher.MatchDomainBitmap(qName)
bin128 := make([]string, 0, len(ips))
for _, ip := range ips {
bin128 = append(bin128, prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128)))
}
goodSubrule := false
badRule := false
for i, match := range m.matches {
if badRule || goodSubrule {
goto beforeNextLoop
}
switch match.Type {
case consts.MatchType_DomainSet:
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
goodSubrule = true
}
case consts.MatchType_IpSet:
for _, bin128 := range bin128 {
// Check if any of IP hit the rule.
if m.ipSet[match.Value].HasPrefix(bin128) {
goodSubrule = true
break
}
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
goodSubrule = true
}
case consts.MatchType_Upstream:
if upstream == consts.DnsRequestOutboundIndex(match.Value) {
goodSubrule = true
}
case consts.MatchType_Fallback:
goodSubrule = true
default:
return 0, fmt.Errorf("unknown match type: %v", match.Type)
}
beforeNextLoop:
upstream := consts.DnsResponseOutboundIndex(match.Upstream)
if upstream != consts.DnsResponseOutboundIndex_LogicalOr {
// This match_set reaches the end of subrule.
// We are now at end of rule, or next match_set belongs to another
// subrule.
if goodSubrule == match.Not {
// This subrule does not hit.
badRule = true
}
// Reset goodSubrule.
goodSubrule = false
}
if upstream&consts.DnsResponseOutboundIndex_LogicalMask !=
consts.DnsResponseOutboundIndex_LogicalMask {
// Tail of a rule (line).
// Decide whether to hit.
if !badRule {
return upstream, nil
}
badRule = false
}
}
return 0, fmt.Errorf("no match set hit")
}

View File

@ -3,68 +3,68 @@
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
*/
package control
package dns
import (
"context"
"fmt"
"github.com/mzz2017/softwind/protocol/direct"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/common/netutils"
"net"
"net/url"
"strconv"
"sync"
"time"
)
type DnsUpstreamScheme string
type UpstreamScheme string
const (
DnsUpstreamScheme_TCP DnsUpstreamScheme = "tcp"
DnsUpstreamScheme_UDP DnsUpstreamScheme = "udp"
DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp"
UpstreamScheme_TCP UpstreamScheme = "tcp"
UpstreamScheme_UDP UpstreamScheme = "udp"
UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp"
)
func (s DnsUpstreamScheme) ContainsTcp() bool {
func (s UpstreamScheme) ContainsTcp() bool {
switch s {
case DnsUpstreamScheme_TCP,
DnsUpstreamScheme_TCP_UDP:
case UpstreamScheme_TCP,
UpstreamScheme_TCP_UDP:
return true
default:
return false
}
}
type DnsUpstream struct {
Scheme DnsUpstreamScheme
Hostname string
Port uint16
*netutils.Ip46
}
func ParseDnsUpstream(dnsUpstream *url.URL) (scheme DnsUpstreamScheme, hostname string, port uint16, err error) {
func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, err error) {
var __port string
switch scheme = DnsUpstreamScheme(dnsUpstream.Scheme); scheme {
case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP:
__port = dnsUpstream.Port()
switch scheme = UpstreamScheme(raw.Scheme); scheme {
case UpstreamScheme_TCP, UpstreamScheme_UDP, UpstreamScheme_TCP_UDP:
__port = raw.Port()
if __port == "" {
__port = "53"
}
default:
return "", "", 0, fmt.Errorf("unexpected dns_upstream format")
}
_port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16)
port = uint16(_port)
_port, err := strconv.ParseUint(__port, 10, 16)
if err != nil {
return "", "", 0, fmt.Errorf("parse dns_upstream port: %v", err)
return "", "", 0, fmt.Errorf("failed to parse dns_upstream port: %v", err)
}
hostname = dnsUpstream.Hostname()
port = uint16(_port)
hostname = raw.Hostname()
return scheme, hostname, port, nil
}
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstream, err error) {
scheme, hostname, port, err := ParseDnsUpstream(dnsUpstream)
type Upstream struct {
Scheme UpstreamScheme
Hostname string
Port uint16
*netutils.Ip46
}
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
scheme, hostname, port, err := ParseRawUpstream(upstream)
if err != nil {
return nil, err
}
@ -79,7 +79,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
}
}()
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false)
if err != nil {
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
}
@ -87,7 +87,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
return nil, fmt.Errorf("dns_upstream has no record")
}
return &DnsUpstream{
return &Upstream{
Scheme: scheme,
Hostname: hostname,
Port: port,
@ -95,7 +95,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
}, nil
}
func (u *DnsUpstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
if u.Ip4.IsValid() && u.Ip6.IsValid() {
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6}
} else {
@ -106,27 +106,31 @@ func (u *DnsUpstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4p
}
}
switch u.Scheme {
case DnsUpstreamScheme_TCP:
case UpstreamScheme_TCP:
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP}
case DnsUpstreamScheme_UDP:
case UpstreamScheme_UDP:
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP}
case DnsUpstreamScheme_TCP_UDP:
case UpstreamScheme_TCP_UDP:
// UDP first.
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP, consts.L4ProtoStr_TCP}
}
return ipversions, l4protos
}
type DnsUpstreamRaw struct {
Raw common.UrlOrEmpty
func (u *Upstream) String() string {
return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port)))
}
type UpstreamResolver struct {
Raw *url.URL
// FinishInitCallback may be invoked again if err is not nil
FinishInitCallback func(raw common.UrlOrEmpty, upstream *DnsUpstream) (err error)
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
mu sync.Mutex
upstream *DnsUpstream
upstream *Upstream
init bool
}
func (u *DnsUpstreamRaw) GetUpstream() (_ *DnsUpstream, err error) {
func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
u.mu.Lock()
defer u.mu.Unlock()
if !u.init {
@ -141,13 +145,9 @@ func (u *DnsUpstreamRaw) GetUpstream() (_ *DnsUpstream, err error) {
}()
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
defer cancel()
if !u.Raw.Empty {
if u.upstream, err = ResolveDnsUpstream(ctx, u.Raw.Url); err != nil {
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
}
} else {
// Empty string. As-is.
}
}
return u.upstream, nil
}

View File

@ -90,11 +90,8 @@ func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) {
}
func (a *AliveDialerSet) printLatencies() {
if !a.log.IsLevelEnabled(logrus.TraceLevel) {
return
}
var builder strings.Builder
builder.WriteString(fmt.Sprintf("%v (%v):\n", a.dialerGroupName, a.CheckTyp.String()))
builder.WriteString(fmt.Sprintf("Group '%v' [%v]:\n", a.dialerGroupName, a.CheckTyp.String()))
for _, d := range a.inorderedAliveDialerSet {
latency, ok := a.dialerToLatency[d]
if !ok {
@ -210,9 +207,13 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
string(a.selectionPolicy): a.minLatency.latency,
"group": a.dialerGroupName,
"network": a.CheckTyp.String(),
"new dialer": a.minLatency.dialer.Name(),
"old dialer": oldDialerName,
"new_dialer": a.minLatency.dialer.Name(),
"old_dialer": oldDialerName,
}).Infof("Group %vselects dialer", re)
if a.log.IsLevelEnabled(logrus.TraceLevel) {
a.printLatencies()
}
} else {
// Alive -> not alive
defer a.aliveChangeCallback(false)
@ -221,9 +222,6 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
"network": a.CheckTyp.String(),
}).Infof("Group has no dialer alive")
}
if a.log.IsLevelEnabled(logrus.TraceLevel) {
a.printLatencies()
}
}
} else {
if alive && minPolicy && a.minLatency.dialer == nil {

View File

@ -118,7 +118,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
if err != nil {
return nil, err
}
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
if err != nil {
return nil, err
}
@ -153,7 +153,7 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort string) (opt *CheckDns
if err != nil {
return nil, fmt.Errorf("bad port: %v", err)
}
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
if err != nil {
return nil, err
}
@ -409,6 +409,10 @@ func (d *Dialer) NotifyCheck() {
}
}
func (d *Dialer) MustGetLatencies10(typ *NetworkType) *LatenciesN {
return d.mustGetCollection(typ).Latencies10
}
// RegisterAliveDialerSet is thread-safe.
func (d *Dialer) RegisterAliveDialerSet(a *AliveDialerSet) {
if a == nil {

View File

@ -19,7 +19,7 @@ type DialerSelectionPolicy struct {
FixedIndex int
}
func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) {
func NewDialerSelectionPolicyFromGroupParam(param *config.Group) (policy *DialerSelectionPolicy, err error) {
switch val := param.Policy.(type) {
case string:
switch consts.DialerSelectionPolicy(val) {

View File

@ -15,6 +15,8 @@ import (
"strings"
)
var ValidDomainChars = trie.NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^"))
type AhocorasickSlimtrie struct {
validAcIndexes []int
validTrieIndexes []int
@ -173,7 +175,7 @@ func (n *AhocorasickSlimtrie) Build() (err error) {
}
toBuild = ToSuffixTrieStrings(toBuild)
sort.Strings(toBuild)
n.trie[i], err = trie.NewTrie(toBuild)
n.trie[i], err = trie.NewTrie(toBuild, ValidDomainChars)
if err != nil {
return err
}

View File

@ -94,7 +94,6 @@ var TestSample = []string{
}
type RoutingMatcherBuilder struct {
*routing.DefaultMatcherBuilder
outboundName2Id map[string]uint8
simulatedDomainSet []routing.DomainSet
Fallback string
@ -118,7 +117,7 @@ func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string,
consts.RoutingDomainKey_Keyword,
consts.RoutingDomainKey_Suffix:
default:
b.err = fmt.Errorf("AddDomain: unsupported key: %v", key)
b.err = fmt.Errorf("addDomain: unsupported key: %v", key)
return
}
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
@ -132,22 +131,16 @@ func getDomain() (simulatedDomainSet []routing.DomainSet, err error) {
var rules []*config_parser.RoutingRule
sections, err := config_parser.Parse(`
routing {
pname(NetworkManager, dnsmasq, systemd-resolved) -> must_direct # Traffic of DNS in local must be direct to avoid loop when binding to WAN.
pname(sogou-qimpanel, sogou-qimpanel-watchdog) -> block
ip(geoip:private, 224.0.0.0/3, 'ff00::/8') -> direct # Put it in front unless you know what you're doing.
domain(geosite:bing)->us
domain(full:dns.google) && port(53) -> direct
domain(full:dns.google) -> direct
domain(geosite:category-ads-all) -> block
ip(geoip:private) -> direct
ip(geoip:cn) -> direct
domain(geosite:cn) -> direct
fallback: my_group
}`)
if err != nil {
return nil, err
}
var r config.Routing
if err = config.RoutingRuleAndParamParser(reflect.ValueOf(&r), sections[0]); err != nil {
if err = config.SectionParser(reflect.ValueOf(&r), sections[0]); err != nil {
return nil, err
}
if rules, err = routing.ApplyRulesOptimizers(r.Rules,
@ -159,8 +152,13 @@ routing {
return nil, fmt.Errorf("ApplyRulesOptimizers error:\n%w", err)
}
builder := RoutingMatcherBuilder{}
if err = routing.ApplyMatcherBuilder(logrus.StandardLogger(), &builder, rules, r.Fallback); err != nil {
return nil, fmt.Errorf("ApplyMatcherBuilder: %w", err)
rb := routing.NewRulesBuilder(logrus.StandardLogger())
rb.RegisterFunctionParser("domain", func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
return nil
})
if err = rb.Apply(rules); err != nil {
return nil, fmt.Errorf("Apply: %w", err)
}
return builder.simulatedDomainSet, nil
}

View File

@ -0,0 +1,139 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package routing
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/pkg/config_parser"
"net/netip"
"strings"
)
type FunctionParser func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)
// Preset function parser factories.
// PlainParserFactory is for style unity.
func PlainParserFactory(callback func(f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
return callback(f, key, paramValueGroup, overrideOutbound)
}
}
// EmptyKeyPlainParserFactory only accepts function with empty key.
func EmptyKeyPlainParserFactory(callback func(f *config_parser.Function, values []string, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
if key != "" {
return fmt.Errorf("this function cannot accept a key")
}
return callback(f, paramValueGroup, overrideOutbound)
}
}
func IpParserFactory(callback func(f *config_parser.Function, cidrs []netip.Prefix, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
cidrs, err := parsePrefixes(paramValueGroup)
if err != nil {
return err
}
return callback(f, cidrs, overrideOutbound)
}
}
func MacParserFactory(callback func(f *config_parser.Function, macAddrs [][6]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var macAddrs [][6]byte
for _, v := range paramValueGroup {
mac, err := common.ParseMac(v)
if err != nil {
return err
}
macAddrs = append(macAddrs, mac)
}
return callback(f, macAddrs, overrideOutbound)
}
}
func PortRangeParserFactory(callback func(f *config_parser.Function, portRanges [][2]uint16, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var portRanges [][2]uint16
for _, v := range paramValueGroup {
portRange, err := common.ParsePortRange(v)
if err != nil {
return err
}
portRanges = append(portRanges, portRange)
}
return callback(f, portRanges, overrideOutbound)
}
}
func L4ProtoParserFactory(callback func(f *config_parser.Function, l4protoType consts.L4ProtoType, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var l4protoType consts.L4ProtoType
for _, v := range paramValueGroup {
switch v {
case "tcp":
l4protoType |= consts.L4ProtoType_TCP
case "udp":
l4protoType |= consts.L4ProtoType_UDP
}
}
return callback(f, l4protoType, overrideOutbound)
}
}
func IpVersionParserFactory(callback func(f *config_parser.Function, ipVersion consts.IpVersionType, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var ipVersion consts.IpVersionType
for _, v := range paramValueGroup {
switch v {
case "4":
ipVersion |= consts.IpVersion_4
case "6":
ipVersion |= consts.IpVersion_6
}
}
return callback(f, ipVersion, overrideOutbound)
}
}
func ProcessNameParserFactory(callback func(f *config_parser.Function, procNames [][consts.TaskCommLen]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
var procNames [][consts.TaskCommLen]byte
for _, v := range paramValueGroup {
if len([]byte(v)) > consts.TaskCommLen {
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
}
procNames = append(procNames, toProcessName(v))
}
return callback(f, procNames, overrideOutbound)
}
}
func parsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
for _, value := range values {
toParse := value
if strings.LastIndexByte(value, '/') == -1 {
toParse += "/32"
}
prefix, err := netip.ParsePrefix(toParse)
if err != nil {
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
}
cidrs = append(cidrs, prefix)
}
return cidrs, nil
}
func toProcessName(processName string) (procName [consts.TaskCommLen]byte) {
n := []byte(processName)
copy(procName[:], n)
return procName
}

View File

@ -8,18 +8,11 @@ package routing
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/pkg/config_parser"
"net/netip"
"strconv"
"strings"
)
var FakeOutbound_MUST_DIRECT = consts.OutboundMustDirect.String()
var FakeOutbound_AND = consts.OutboundLogicalAnd.String()
var FakeOutbound_OR = consts.OutboundLogicalOr.String()
type DomainSet struct {
Key consts.RoutingDomainKey
RuleIndex int
@ -31,22 +24,71 @@ type Outbound struct {
Mark uint32
}
type MatcherBuilder interface {
AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound)
AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound)
AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound)
AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound)
AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound)
AddFallback(outbound *Outbound)
AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound)
AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound)
type RulesBuilder struct {
log *logrus.Logger
parsers map[string]FunctionParser
}
func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
func NewRulesBuilder(log *logrus.Logger) *RulesBuilder {
return &RulesBuilder{
log: log,
parsers: make(map[string]FunctionParser),
}
}
func (b *RulesBuilder) RegisterFunctionParser(funcName string, parser FunctionParser) {
b.parsers[funcName] = parser
}
func (b *RulesBuilder) Apply(rules []*config_parser.RoutingRule) (err error) {
for _, rule := range rules {
b.log.Debugln("[rule]", rule.String(true))
outbound, err := ParseOutbound(&rule.Outbound)
if err != nil {
return err
}
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
for iFunc, f := range rule.AndFunctions {
// f is like: domain(domain:baidu.com)
functionParser, ok := b.parsers[f.Name]
if !ok {
return fmt.Errorf("unknown function: %v", f.Name)
}
paramValueGroups, keyOrder := groupParamValuesByKey(f.Params)
for jMatchSet, key := range keyOrder {
paramValueGroup := paramValueGroups[key]
// Preprocess the outbound.
overrideOutbound := &Outbound{
Name: consts.OutboundLogicalOr.String(),
Mark: outbound.Mark,
}
if jMatchSet == len(keyOrder)-1 {
overrideOutbound.Name = consts.OutboundLogicalAnd.String()
if iFunc == len(rule.AndFunctions)-1 {
overrideOutbound.Name = outbound.Name
}
}
{
// Debug
symNot := ""
if f.Not {
symNot = "!"
}
b.log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound.Name)
}
if err = functionParser(b.log, f, key, paramValueGroup, overrideOutbound); err != nil {
return fmt.Errorf("failed to parse '%v': %w", f.String(false), err)
}
}
}
}
return nil
}
func groupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
groups := make(map[string][]string)
for _, param := range params {
if _, ok := groups[param.Key]; !ok {
@ -57,28 +99,7 @@ func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[strin
return groups, keyOrder
}
func ParsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
for _, value := range values {
toParse := value
if strings.LastIndexByte(value, '/') == -1 {
toParse += "/32"
}
prefix, err := netip.ParsePrefix(toParse)
if err != nil {
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
}
cidrs = append(cidrs, prefix)
}
return cidrs, nil
}
func ToProcessName(processName string) (procName [consts.TaskCommLen]byte) {
n := []byte(processName)
copy(procName[:], n)
return procName
}
func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
func ParseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
outbound = &Outbound{
Name: rawOutbound.Name,
Mark: 0,
@ -98,164 +119,3 @@ func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err
}
return outbound, nil
}
func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*config_parser.RoutingRule, fallbackOutbound interface{}) (err error) {
for _, rule := range rules {
log.Debugln("[rule]", rule.String(true))
outbound, err := parseOutbound(&rule.Outbound)
if err != nil {
return err
}
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
for iFunc, f := range rule.AndFunctions {
// f is like: domain(domain:baidu.com)
paramValueGroups, keyOrder := GroupParamValuesByKey(f.Params)
for jMatchSet, key := range keyOrder {
paramValueGroup := paramValueGroups[key]
// Preprocess the outbound.
overrideOutbound := &Outbound{
Name: FakeOutbound_OR,
Mark: outbound.Mark,
}
if jMatchSet == len(keyOrder)-1 {
overrideOutbound.Name = FakeOutbound_AND
if iFunc == len(rule.AndFunctions)-1 {
overrideOutbound.Name = outbound.Name
}
}
{
// Debug
symNot := ""
if f.Not {
symNot = "!"
}
log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound)
}
builder.AddAnyBefore(f, key, paramValueGroup, overrideOutbound)
switch f.Name {
case consts.Function_Domain:
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
case consts.Function_Ip, consts.Function_SourceIp:
cidrs, err := ParsePrefixes(paramValueGroup)
if err != nil {
return err
}
if f.Name == consts.Function_Ip {
builder.AddIp(f, cidrs, overrideOutbound)
} else {
builder.AddSourceIp(f, cidrs, overrideOutbound)
}
case consts.Function_Mac:
var macAddrs [][6]byte
for _, v := range paramValueGroup {
mac, err := common.ParseMac(v)
if err != nil {
return err
}
macAddrs = append(macAddrs, mac)
}
builder.AddSourceMac(f, macAddrs, overrideOutbound)
case consts.Function_Port, consts.Function_SourcePort:
var portRanges [][2]uint16
for _, v := range paramValueGroup {
portRange, err := common.ParsePortRange(v)
if err != nil {
return err
}
portRanges = append(portRanges, portRange)
}
if f.Name == consts.Function_Port {
builder.AddPort(f, portRanges, overrideOutbound)
} else {
builder.AddSourcePort(f, portRanges, overrideOutbound)
}
case consts.Function_L4Proto:
var l4protoType consts.L4ProtoType
for _, v := range paramValueGroup {
switch v {
case "tcp":
l4protoType |= consts.L4ProtoType_TCP
case "udp":
l4protoType |= consts.L4ProtoType_UDP
}
}
builder.AddL4Proto(f, l4protoType, overrideOutbound)
case consts.Function_IpVersion:
var ipVersion consts.IpVersionType
for _, v := range paramValueGroup {
switch v {
case "4":
ipVersion |= consts.IpVersion_4
case "6":
ipVersion |= consts.IpVersion_6
}
}
builder.AddIpVersion(f, ipVersion, overrideOutbound)
case consts.Function_ProcessName:
var procNames [][consts.TaskCommLen]byte
for _, v := range paramValueGroup {
if len([]byte(v)) > consts.TaskCommLen {
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
}
procNames = append(procNames, ToProcessName(v))
}
builder.AddProcessName(f, procNames, overrideOutbound)
default:
return fmt.Errorf("unsupported function name: %v", f.Name)
}
builder.AddAnyAfter(f, key, paramValueGroup, overrideOutbound)
}
}
}
var rawFallback *config_parser.Function
switch fallback := fallbackOutbound.(type) {
case string:
rawFallback = &config_parser.Function{Name: fallback}
case *config_parser.Function:
rawFallback = fallback
default:
return fmt.Errorf("unknown type of 'fallback' in section routing: %T", fallback)
}
fallback, err := parseOutbound(rawFallback)
if err != nil {
return err
}
builder.AddAnyBefore(&config_parser.Function{
Name: "fallback",
}, "", nil, fallback)
builder.AddFallback(fallback)
builder.AddAnyAfter(&config_parser.Function{
Name: "fallback",
}, "", nil, fallback)
return nil
}
type DefaultMatcherBuilder struct {
}
func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddFallback(outbound *Outbound) {}
func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound) {
}
func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound) {
}

View File

@ -157,7 +157,7 @@ func (r *CryptoFrameRelocation) BytesFromPool() []byte {
return pool.Get(0)
}
right := r.o[len(r.o)-1]
return r.copyBytes(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
return r.copyBytesToPool(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
}
// RangeFromPool copy bytes from iUpperAppOffset to jUpperAppOffset.
@ -191,11 +191,11 @@ func (r *CryptoFrameRelocation) RangeFromPool(i, j int) []byte {
}
}
return r.copyBytes(iOuter, iInner, jOuter, jInner, j-i+1)
return r.copyBytesToPool(iOuter, iInner, jOuter, jInner, j-i+1)
}
// copyBytes copy bytes including i and j.
func (r *CryptoFrameRelocation) copyBytes(iOuter, iInner, jOuter, jInner, size int) []byte {
// copyBytesToPool copy bytes including i and j.
func (r *CryptoFrameRelocation) copyBytesToPool(iOuter, iInner, jOuter, jInner, size int) []byte {
b := pool.Get(size)
//io := r.o[iOuter]
k := 0

View File

@ -8,6 +8,7 @@ package sniffing
import (
"github.com/mzz2017/softwind/pool"
"io"
"sync"
)
type Sniffer struct {
@ -15,12 +16,13 @@ type Sniffer struct {
buf []byte
bufAt int
stream bool
readMu sync.Mutex
}
func NewStreamSniffer(r io.Reader, bufSize int) *Sniffer {
s := &Sniffer{
r: r,
buf: pool.Get(bufSize),
buf: make([]byte, bufSize),
stream: true,
}
return s
@ -37,6 +39,8 @@ func NewPacketSniffer(data []byte) *Sniffer {
type sniff func() (d string, err error)
func (s *Sniffer) SniffTcp() (d string, err error) {
s.readMu.Lock()
defer s.readMu.Unlock()
if s.stream {
n, err := s.r.Read(s.buf)
if err != nil {
@ -65,6 +69,8 @@ func (s *Sniffer) SniffTcp() (d string, err error) {
}
func (s *Sniffer) Read(p []byte) (n int, err error) {
s.readMu.Lock()
defer s.readMu.Unlock()
if s.buf != nil && s.bufAt < len(s.buf) {
// Read buf first.
n = copy(p, s.buf[s.bufAt:])
@ -84,6 +90,5 @@ func (s *Sniffer) Read(p []byte) (n int, err error) {
}
func (s *Sniffer) Close() (err error) {
// DO NOT use pool.Put() here because Close() may not interrupt the reading, which will modify the value of the pool buffer.
return nil
}

View File

@ -7,7 +7,6 @@ package config
import (
"fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/pkg/config_parser"
"reflect"
"time"
@ -22,7 +21,7 @@ type Global struct {
UdpCheckDns string `mapstructure:"udp_check_dns" default:"dns.google:53"`
CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"`
CheckTolerance time.Duration `mapstructure:"check_tolerance" default:"0"`
DnsUpstream common.UrlOrEmpty `mapstructure:"dns_upstream" default:""`
DnsUpstream string `mapstructure:"dns_upstream" default:"<empty>"`
LanInterface []string `mapstructure:"lan_interface"`
LanNatDirect bool `mapstructure:"lan_nat_direct" default:"true"`
WanInterface []string `mapstructure:"wan_interface"`
@ -30,28 +29,55 @@ type Global struct {
DialMode string `mapstructure:"dial_mode" default:"domain"`
}
type Group struct {
Name string
Param GroupParam
type FunctionOrString interface{}
func FunctionOrStringToFunction(fs FunctionOrString) (f *config_parser.Function) {
switch fs := fs.(type) {
case string:
return &config_parser.Function{Name: fs}
case *config_parser.Function:
return fs
default:
panic(fmt.Sprintf("unknown type of 'fallback' in section routing: %T", fs))
}
}
type GroupParam struct {
type Group struct {
Name string `mapstructure:"_"`
Filter []*config_parser.Function `mapstructure:"filter"`
Policy interface{} `mapstructure:"policy" required:""`
}
type DnsRequestRouting struct {
Rules []*config_parser.RoutingRule `mapstructure:"_"`
Fallback FunctionOrString `mapstructure:"fallback" required:""`
}
type DnsResponseRouting struct {
Rules []*config_parser.RoutingRule `mapstructure:"_"`
Fallback FunctionOrString `mapstructure:"fallback" required:""`
}
type Dns struct {
Upstream []string `mapstructure:"upstream"`
Routing struct {
Request DnsRequestRouting `mapstructure:"request"`
Response DnsResponseRouting `mapstructure:"response"`
} `mapstructure:"routing"`
}
type Routing struct {
Rules []*config_parser.RoutingRule `mapstructure:"_"`
Fallback interface{} `mapstructure:"fallback"`
Final interface{} `mapstructure:"final"`
Fallback FunctionOrString `mapstructure:"fallback"`
Final FunctionOrString `mapstructure:"final"`
}
type Params struct {
Global Global `mapstructure:"global" parser:"ParamParser"`
Subscription []string `mapstructure:"subscription" parser:"StringListParser"`
Node []string `mapstructure:"node" parser:"StringListParser"`
Group []Group `mapstructure:"group" parser:"GroupListParser"`
Routing Routing `mapstructure:"routing" parser:"RoutingRuleAndParamParser"`
Global Global `mapstructure:"global" required:""`
Subscription []string `mapstructure:"subscription"`
Node []string `mapstructure:"node"`
Group []Group `mapstructure:"group" required:""`
Routing Routing `mapstructure:"routing" required:""`
Dns Dns `mapstructure:"dns"`
}
// New params from sections. This func assumes merging (section "include") and deduplication for section names has been executed.
@ -82,21 +108,15 @@ func New(sections []*config_parser.Section) (params *Params, err error) {
}
section, ok := nameToSection[sectionName]
if !ok {
if _, required := structField.Tag.Lookup("required"); required {
return nil, fmt.Errorf("section %v is required but not provided", sectionName)
} else {
continue
}
// Find corresponding parser func.
parserName, ok := structField.Tag.Lookup("parser")
if !ok {
return nil, fmt.Errorf("no parser is specified in field %v", structField.Name)
}
parser, ok := ParserMap[parserName]
if !ok {
return nil, fmt.Errorf("unknown parser %v in field %v", parserName, structField.Name)
}
// Parse section and unmarshal to field.
if err := parser(field.Addr(), section.Val); err != nil {
if err := SectionParser(field.Addr(), section.Val); err != nil {
return nil, fmt.Errorf("failed to parse \"%v\": %w", sectionName, err)
}
section.Parsed = true

View File

@ -8,12 +8,15 @@ package config
import (
"fmt"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
)
type patch func(params *Params) error
var patches = []patch{
patchRoutingFallback,
patchEmptyDns,
patchDeprecatedGlobalDnsUpstream,
}
func patchRoutingFallback(params *Params) error {
@ -28,3 +31,20 @@ func patchRoutingFallback(params *Params) error {
}
return nil
}
func patchEmptyDns(params *Params) error {
if params.Dns.Routing.Request.Fallback == nil {
params.Dns.Routing.Request.Fallback = consts.DnsRequestOutboundIndex_AsIs.String()
}
if params.Dns.Routing.Response.Fallback == nil {
params.Dns.Routing.Response.Fallback = consts.DnsResponseOutboundIndex_Accept.String()
}
return nil
}
func patchDeprecatedGlobalDnsUpstream(params *Params) error {
if params.Global.DnsUpstream != "<empty>" {
return fmt.Errorf("'global.dns_upstream' was deprecated, please refer to the latest examples and docs for help")
}
return nil
}

View File

@ -13,16 +13,6 @@ import (
"strings"
)
// Parser is section items parser
type Parser func(to reflect.Value, section *config_parser.Section) error
var ParserMap = map[string]Parser{
"StringListParser": StringListParser,
"ParamParser": ParamParser,
"GroupListParser": GroupListParser,
"RoutingRuleAndParamParser": RoutingRuleAndParamParser,
}
func StringListParser(to reflect.Value, section *config_parser.Section) error {
if to.Kind() != reflect.Pointer {
return fmt.Errorf("StringListParser can only unmarshal section to *[]string")
@ -44,7 +34,7 @@ func StringListParser(to reflect.Value, section *config_parser.Section) error {
return nil
}
func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []reflect.Type) error {
func ParamParser(to reflect.Value, section *config_parser.Section, ignoreType []reflect.Type) error {
if to.Kind() != reflect.Pointer {
return fmt.Errorf("ParamParser can only unmarshal section to *struct")
}
@ -67,7 +57,7 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
// Set up key to field mapping.
key, ok := structField.Tag.Lookup("mapstructure")
if !ok {
return fmt.Errorf("field %v has no mapstructure tag", structField.Name)
return fmt.Errorf("field \"%v\" has no mapstructure tag", structField.Name)
}
if key == "_" {
// omit
@ -95,11 +85,11 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
switch itemVal := item.Value.(type) {
case *config_parser.Param:
if itemVal.Key == "" {
return fmt.Errorf("section %v does not support text without a key: %v", section.Name, itemVal.String(true))
return fmt.Errorf("unsupported text without a key: %v", itemVal.String(true))
}
field, ok := keyToField[itemVal.Key]
if !ok {
return fmt.Errorf("section %v does not support key: %v", section.Name, itemVal.Key)
return fmt.Errorf("unexpected key: %v", itemVal.Key)
}
if itemVal.AndFunctions != nil {
// AndFunctions.
@ -108,7 +98,7 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
field.Val.Type() == reflect.TypeOf(itemVal.AndFunctions) {
field.Val.Set(reflect.ValueOf(itemVal.AndFunctions))
} else {
return fmt.Errorf("failed to parse \"%v.%v\": value \"%v\" cannot be convert to %v", section.Name, itemVal.Key, itemVal.Val, field.Val.Type().String())
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().String())
}
} else {
// String value.
@ -122,21 +112,42 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
for _, value := range values {
vPointerNew := reflect.New(field.Val.Type().Elem())
if !common.FuzzyDecode(vPointerNew.Interface(), value) {
return fmt.Errorf("failed to parse \"%v.%v\": value \"%v\" cannot be convert to %v", section.Name, itemVal.Key, itemVal.Val, field.Val.Type().Elem().String())
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().Elem().String())
}
field.Val.Set(reflect.Append(field.Val, vPointerNew.Elem()))
}
default:
// Field is not interface{}, we can decode.
if !common.FuzzyDecode(field.Val.Addr().Interface(), itemVal.Val) {
return fmt.Errorf("failed to parse \"%v.%v\": value \"%v\" cannot be convert to %v", section.Name, itemVal.Key, itemVal.Val, field.Val.Type().String())
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().String())
}
}
}
field.Set = true
case *config_parser.Section:
// Named section config item.
field, ok := keyToField[itemVal.Name]
if !ok {
return fmt.Errorf("unexpected key: %v", itemVal.Name)
}
if err := SectionParser(field.Val.Addr(), itemVal); err != nil {
return fmt.Errorf("failed to parse %v: %w", itemVal.Name, err)
}
field.Set = true
case *config_parser.RoutingRule:
// Assign. "to" should have field "Rules".
structField, ok := to.Type().FieldByName("Rules")
if !ok || structField.Type != reflect.TypeOf([]*config_parser.RoutingRule{}) {
return fmt.Errorf("unexpected type: \"routing rule\": %v", itemVal.String(true))
}
if structField.Tag.Get("mapstructure") != "_" {
return fmt.Errorf("a []*RoutingRule field \"Rules\" with mapstructure:\"_\" is required in struct %v to parse section", to.Type().String())
}
field := to.FieldByName("Rules")
field.Set(reflect.Append(field, reflect.ValueOf(itemVal)))
default:
if _, ignore := ignoreTypeSet[reflect.TypeOf(itemVal)]; !ignore {
return fmt.Errorf("section %v does not support type %v: %v", section.Name, item.Type.String(), item.String())
return fmt.Errorf("unexpected type %v: %v", item.Type.String(), item.String())
}
}
}
@ -155,76 +166,66 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
return nil
}
func ParamParser(to reflect.Value, section *config_parser.Section) error {
return paramParser(to, section, nil)
}
func GroupListParser(to reflect.Value, section *config_parser.Section) error {
func SectionParser(to reflect.Value, section *config_parser.Section) error {
if to.Kind() != reflect.Pointer {
return fmt.Errorf("GroupListParser can only unmarshal section to *[]Group")
return fmt.Errorf("SectionParser can only unmarshal section to a pointer")
}
to = to.Elem()
if to.Type() != reflect.TypeOf([]Group{}) {
return fmt.Errorf("GroupListParser can only unmarshal section to *[]Group")
switch to.Kind() {
case reflect.Slice:
elemType := to.Type().Elem()
switch elemType.Kind() {
case reflect.String:
return StringListParser(to.Addr(), section)
case reflect.Struct:
// "to" is a section list (sections in section).
/**
to {
field1 {
...
}
field2 {
...
}
}
should be parsed to:
to []struct {
Name string `mapstructure: "_"`
...
}
*/
// The struct should contain Name.
nameStructField, ok := elemType.FieldByName("Name")
if !ok || nameStructField.Type.Kind() != reflect.String || nameStructField.Tag.Get("mapstructure") != "_" {
return fmt.Errorf("a string field \"Name\" with mapstructure:\"_\" is required in struct %v to parse section", to.Type().Elem().String())
}
// Scan sections.
for _, item := range section.Items {
elem := reflect.New(elemType).Elem()
switch itemVal := item.Value.(type) {
case *config_parser.Section:
group := Group{
Name: itemVal.Name,
Param: GroupParam{},
elem.FieldByName("Name").SetString(itemVal.Name)
if err := SectionParser(elem.Addr(), itemVal); err != nil {
return fmt.Errorf("error when parse \"%v\": %w", itemVal.Name, err)
}
paramVal := reflect.ValueOf(&group.Param)
if err := paramParser(paramVal, itemVal, nil); err != nil {
return fmt.Errorf("failed to parse \"%v\": %w", itemVal.Name, err)
}
to.Set(reflect.Append(to, reflect.ValueOf(group)))
to.Set(reflect.Append(to, elem))
default:
return fmt.Errorf("section %v does not support type %v: %v", section.Name, item.Type.String(), item.String())
return fmt.Errorf("unmatched type: %v -> %v", item.Type.String(), elemType)
}
}
return nil
}
func RoutingRuleAndParamParser(to reflect.Value, section *config_parser.Section) error {
if to.Kind() != reflect.Pointer {
return fmt.Errorf("RoutingRuleAndParamParser can only unmarshal section to *struct")
}
to = to.Elem()
if to.Kind() != reflect.Struct {
return fmt.Errorf("RoutingRuleAndParamParser can only unmarshal section to *struct")
}
// Find the first []*RoutingRule field to unmarshal.
targetType := reflect.TypeOf([]*config_parser.RoutingRule{})
var ruleTo *reflect.Value
for i := 0; i < to.NumField(); i++ {
field := to.Field(i)
if field.Type() == targetType {
ruleTo = &field
break
}
}
if ruleTo == nil {
return fmt.Errorf(`no %v field found`, targetType.String())
}
// Parse and unmarshal list of RoutingRule to ruleTo.
for _, item := range section.Items {
switch itemVal := item.Value.(type) {
case *config_parser.RoutingRule:
ruleTo.Set(reflect.Append(*ruleTo, reflect.ValueOf(itemVal)))
case *config_parser.Param:
// pass
default:
return fmt.Errorf("section %v does not support type %v: %v", section.Name, item.Type.String(), item.String())
goto unsupported
}
case reflect.Struct:
// Section.
return ParamParser(to.Addr(), section, nil)
default:
goto unsupported
}
// Parse Param.
return paramParser(to.Addr(), section,
[]reflect.Type{reflect.TypeOf(&config_parser.RoutingRule{})},
)
panic("code should not reach here")
unsupported:
return fmt.Errorf("unsupported section type %v", to.Type())
}

View File

@ -14,6 +14,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/dns"
"github.com/v2rayA/dae/component/outbound"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/component/routing"
@ -24,9 +25,9 @@ import (
"golang.org/x/sys/unix"
"net"
"net/netip"
"net/url"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
@ -44,10 +45,8 @@ type ControlPlane struct {
// TODO: add mutex?
outbounds []*outbound.DialerGroup
// mutex protects the dnsCache.
dnsCacheMu sync.Mutex
dnsCache map[string]*dnsCache
dnsUpstream DnsUpstreamRaw
dnsController *DnsController
onceNetworkReady sync.Once
dialMode consts.DialMode
@ -60,6 +59,7 @@ func NewControlPlane(
groups []config.Group,
routingA *config.Routing,
global *config.Global,
dnsConfig *config.Dns,
) (c *ControlPlane, err error) {
kernelVersion, e := internal.KernelVersion()
if e != nil {
@ -199,13 +199,6 @@ func NewControlPlane(
}
/// DialerGroups (outbounds).
checkDnsTcp := false
if !global.DnsUpstream.Empty {
if scheme, _, _, err := ParseDnsUpstream(global.DnsUpstream.Url); err == nil &&
scheme.ContainsTcp() {
checkDnsTcp = true
}
}
if global.AllowInsecure {
log.Warnln("AllowInsecure is enabled, but it is not recommended. Please make sure you have to turn it on.")
}
@ -215,7 +208,7 @@ func NewControlPlane(
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: global.UdpCheckDns},
CheckInterval: global.CheckInterval,
CheckTolerance: global.CheckTolerance,
CheckDnsTcp: checkDnsTcp,
CheckDnsTcp: true,
AllowInsecure: global.AllowInsecure,
}
outbounds := []*outbound.DialerGroup{
@ -237,12 +230,12 @@ func NewControlPlane(
dialerSet := outbound.NewDialerSetFromLinks(option, tagToNodeList)
for _, group := range groups {
// Parse policy.
policy, err := outbound.NewDialerSelectionPolicyFromGroupParam(&group.Param)
policy, err := outbound.NewDialerSelectionPolicyFromGroupParam(&group)
if err != nil {
return nil, fmt.Errorf("failed to create group %v: %w", group.Name, err)
}
// Filter nodes with user given filters.
dialers, err := dialerSet.Filter(group.Param.Filter)
dialers, err := dialerSet.Filter(group.Filter)
if err != nil {
return nil, fmt.Errorf(`failed to create group "%v": %w`, group.Name, err)
}
@ -276,7 +269,7 @@ func NewControlPlane(
outboundId2Name[uint8(i)] = o.Name
}
core.outboundId2Name = outboundId2Name
builder := NewRoutingMatcherBuilder(outboundName2Id, &bpf)
// Apply rules optimizers.
var rules []*config_parser.RoutingRule
if rules, err = routing.ApplyRulesOptimizers(routingA.Rules,
&routing.RefineFunctionParamKeyOptimizer{},
@ -294,18 +287,24 @@ func NewControlPlane(
}
log.Debugf("RoutingA:\n%vfallback: %v\n", debugBuilder.String(), routingA.Fallback)
}
if err = routing.ApplyMatcherBuilder(log, builder, rules, routingA.Fallback); err != nil {
return nil, fmt.Errorf("ApplyMatcherBuilder: %w", err)
// Parse rules and build.
builder, err := NewRoutingMatcherBuilder(log, rules, outboundName2Id, &bpf, routingA.Fallback)
if err != nil {
return nil, fmt.Errorf("NewRoutingMatcherBuilder: %w", err)
}
if err = builder.BuildKernspace(); err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err)
}
routingMatcher, err := builder.BuildUserspace()
routingMatcher, err := builder.BuildUserspace(core.bpf.LpmArrayMap)
if err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err)
}
/// Dial mode.
dialMode, err := consts.ParseDialMode(global.DialMode)
if err != nil {
return nil, err
}
c = &ControlPlane{
log: log,
@ -313,50 +312,72 @@ func NewControlPlane(
deferFuncs: nil,
listenIp: "0.0.0.0",
outbounds: outbounds,
dnsCacheMu: sync.Mutex{},
dnsCache: make(map[string]*dnsCache),
dnsUpstream: DnsUpstreamRaw{
Raw: global.DnsUpstream,
FinishInitCallback: nil,
},
dialMode: dialMode,
routingMatcher: routingMatcher,
}
/// DNS upstream
c.dnsUpstream.FinishInitCallback = c.finishInitDnsUpstreamResolve
// Try to invoke once to avoid dns leaking at the very beginning.
_, _ = c.dnsUpstream.GetUpstream()
/// DNS upstream.
dnsUpstream, err := dns.New(log, dnsConfig, &dns.NewOption{
UpstreamReadyCallback: c.dnsUpstreamReadyCallback,
})
if err != nil {
return nil, err
}
/// Dns controller.
c.dnsController, err = NewDnsController(dnsUpstream, &DnsControllerOption{
Log: log,
CacheAccessCallback: func(cache *DnsCache) (err error) {
// Write mappings into eBPF map:
// IP record (from dns lookup) -> domain routing
if err = core.BatchUpdateDomainRouting(cache); err != nil {
return fmt.Errorf("BatchUpdateDomainRouting: %w", err)
}
return nil
},
NewCache: func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) {
return &DnsCache{
DomainBitmap: c.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn),
Answers: answers,
Deadline: deadline,
}, nil
},
BestDialerChooser: c.chooseBestDnsDialer,
})
// Call GC to release memory.
runtime.GC()
return c, nil
}
func (c *ControlPlane) finishInitDnsUpstreamResolve(raw common.UrlOrEmpty, dnsUpstream *DnsUpstream) (err error) {
func (c *ControlPlane) dnsUpstreamReadyCallback(raw *url.URL, dnsUpstream *dns.Upstream) (err error) {
/// Notify dialers to check.
c.onceNetworkReady.Do(func() {
for _, out := range c.outbounds {
for _, d := range out.Dialers {
d.NotifyCheck()
}
}
/// Updates dns cache to support domain routing for hostname of dns_upstream.
if !raw.Empty {
ip4in6 := dnsUpstream.Ip4.As16()
ip6 := dnsUpstream.Ip6.As16()
if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]),
Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]),
HasIp4: dnsUpstream.Ip4.IsValid(),
HasIp6: dnsUpstream.Ip6.IsValid(),
Port: common.Htons(dnsUpstream.Port),
}, ebpf.UpdateAny); err != nil {
if dnsUpstream != nil {
// Control plane DNS routing.
if err = c.core.bpf.ParamMap.Update(consts.ControlPlaneDnsRoutingKey, uint32(1), ebpf.UpdateAny); err != nil {
return
}
} else {
// As-is.
if err = c.core.bpf.ParamMap.Update(consts.ControlPlaneDnsRoutingKey, uint32(0), ebpf.UpdateAny); err != nil {
return
}
}
})
if err != nil {
return err
}
/// Update dns cache to support domain routing for hostname of dns_upstream.
if dnsUpstream == nil {
return nil
}
/// Updates dns cache to support domain routing for hostname of dns_upstream.
// Ten years later.
deadline := time.Now().Add(24 * time.Hour * 365 * 10)
deadline := time.Now().Add(time.Hour * 24 * 365 * 10)
fqdn := dnsUpstream.Hostname
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
@ -375,9 +396,8 @@ func (c *ControlPlane) finishInitDnsUpstreamResolve(raw common.UrlOrEmpty, dnsUp
A: dnsUpstream.Ip4.As4(),
},
}}
if err = c.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
c = nil
return
if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
return err
}
}
@ -394,21 +414,7 @@ func (c *ControlPlane) finishInitDnsUpstreamResolve(raw common.UrlOrEmpty, dnsUp
AAAA: dnsUpstream.Ip6.As16(),
},
}}
if err = c.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
c = nil
return
}
}
} else {
// Empty string. As-is.
if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
Ip4: [4]uint32{},
Ip6: [4]uint32{},
HasIp4: false,
HasIp6: false,
// Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array.
Port: 0,
}, ebpf.UpdateAny); err != nil {
if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
return err
}
}
@ -421,9 +427,8 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
if !outbound.IsReserved() && domain != "" {
switch c.dialMode {
case consts.DialMode_Domain:
dstIp := common.ConvergeIp(dst.Addr())
cache := c.lookupDnsRespCache(domain, common.AddrToDnsType(dstIp))
if cache != nil && cache.IncludeIp(dstIp) {
cache := c.dnsController.LookupDnsRespCache(domain, common.AddrToDnsType(dst.Addr()))
if cache != nil && cache.IncludeIp(dst.Addr()) {
mode = consts.DialMode_Domain
}
case consts.DialMode_DomainPlus:
@ -552,7 +557,7 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
} else {
realDst = pktDst
}
if e := c.handlePkt(udpConn, data, src, pktDst, realDst, routingResult); e != nil {
if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult); e != nil {
c.log.Warnln("handlePkt:", e)
}
}(newBuf, src)
@ -562,6 +567,103 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
return nil
}
func (c *ControlPlane) chooseBestDnsDialer(
req *udpRequest,
dnsUpstream *dns.Upstream,
) (*dialArgument, error) {
/// Choose the best l4proto+ipversion dialer, and change taregt DNS to the best ipversion DNS upstream for DNS request.
// Get available ipversions and l4protos for DNS upstream.
ipversions, l4protos := dnsUpstream.SupportedNetworks()
var (
bestLatency time.Duration
l4proto consts.L4ProtoStr
ipversion consts.IpVersionStr
bestDialer *dialer.Dialer
bestOutbound *outbound.DialerGroup
bestTarget netip.AddrPort
dialMark uint32
)
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"ipversions": ipversions,
"l4protos": l4protos,
"upstream": dnsUpstream.String(),
}).Traceln("Choose DNS path")
}
// Get the min latency path.
networkType := dialer.NetworkType{
IsDns: true,
}
for _, ver := range ipversions {
for _, proto := range l4protos {
networkType.L4Proto = proto
networkType.IpVersion = ver
var dAddr netip.Addr
switch ver {
case consts.IpVersionStr_4:
dAddr = dnsUpstream.Ip4
case consts.IpVersionStr_6:
dAddr = dnsUpstream.Ip6
default:
return nil, fmt.Errorf("unexpected ipversion: %v", ver)
}
outboundIndex, mark, err := c.Route(req.realSrc, netip.AddrPortFrom(dAddr, dnsUpstream.Port), "", proto.ToL4ProtoType(), req.routingResult)
if err != nil {
return nil, err
}
// Already "must direct".
if outboundIndex == consts.OutboundMustDirect {
outboundIndex = consts.OutboundDirect
}
if int(outboundIndex) >= len(c.outbounds) {
return nil, fmt.Errorf("bad outbound index: %v", outboundIndex)
}
dialerGroup := c.outbounds[outboundIndex]
d, latency, err := dialerGroup.Select(&networkType)
if err != nil {
continue
}
//if c.log.IsLevelEnabled(logrus.TraceLevel) {
// c.log.WithFields(logrus.Fields{
// "name": d.Name(),
// "latency": latency,
// "network": networkType.String(),
// "outbound": dialerGroup.Name,
// }).Traceln("Choice")
//}
if bestDialer == nil || latency < bestLatency {
bestDialer = d
bestOutbound = dialerGroup
bestLatency = latency
l4proto = proto
ipversion = ver
dialMark = mark
if bestLatency == 0 {
break
}
}
}
}
if bestDialer == nil {
return nil, fmt.Errorf("no proper dialer for DNS upstream: %v", dnsUpstream.String())
}
switch ipversion {
case consts.IpVersionStr_4:
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip4, dnsUpstream.Port)
case consts.IpVersionStr_6:
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip6, dnsUpstream.Port)
}
return &dialArgument{
l4proto: l4proto,
ipversion: ipversion,
bestDialer: bestDialer,
bestOutbound: bestOutbound,
bestTarget: bestTarget,
mark: dialMark,
}, nil
}
func (c *ControlPlane) Close() (err error) {
// Invoke defer funcs in reverse order.
for i := len(c.deferFuncs) - 1; i >= 0; i-- {

View File

@ -11,11 +11,14 @@ import (
ciliumLink "github.com/cilium/ebpf/link"
"github.com/safchain/ethtool"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
"github.com/vishvananda/netlink"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix"
"net"
"net/netip"
"os"
"regexp"
)
@ -415,3 +418,42 @@ func (c *ControlPlaneCore) bindWan(ifname string) error {
})
return nil
}
// BatchUpdateDomainRouting update bpf map domain_routing. Since one IP may have multiple domains, this function should
// be invoked every A/AAAA-record lookup.
func (c *ControlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error {
// Parse ips from DNS resp answers.
var ips []netip.Addr
for _, ans := range cache.Answers {
switch ans.Header.Type {
case dnsmessage.TypeA:
ips = append(ips, netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A))
case dnsmessage.TypeAAAA:
ips = append(ips, netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA))
}
}
if len(ips) == 0 {
return nil
}
// Update bpf map.
// Construct keys and vals, and BpfMapBatchUpdate.
var keys [][4]uint32
var vals []bpfDomainRouting
for _, ip := range ips {
ip6 := ip.As16()
keys = append(keys, common.Ipv6ByteSliceToUint32Array(ip6[:]))
r := bpfDomainRouting{}
if len(cache.DomainBitmap) != len(r.Bitmap) {
return fmt.Errorf("domain bitmap length not sync with kern program")
}
copy(r.Bitmap[:], cache.DomainBitmap)
vals = append(vals, r)
}
if _, err := BpfMapBatchUpdate(c.bpf.DomainRoutingMap, keys, vals, &ebpf.BatchOptions{
ElemFlags: uint64(ebpf.UpdateAny),
}); err != nil {
return err
}
return nil
}

View File

@ -1,370 +0,0 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
*/
package control
import (
"encoding/binary"
"errors"
"fmt"
"github.com/cilium/ebpf"
"github.com/mohae/deepcopy"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"golang.org/x/net/dns/dnsmessage"
"hash/fnv"
"math/rand"
"net/netip"
"strings"
"time"
)
var (
SuspectedRushAnswerError = fmt.Errorf("suspected DNS rush-answer")
UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type")
)
type dnsCache struct {
DomainBitmap []uint32
Answers []dnsmessage.Resource
Deadline time.Time
}
func (c *dnsCache) FillInto(req *dnsmessage.Message) {
req.Answers = deepcopy.Copy(c.Answers).([]dnsmessage.Resource)
// Align question and answer Name.
if len(req.Questions) > 0 {
q := req.Questions[0]
for i := range req.Answers {
if strings.EqualFold(req.Answers[i].Header.Name.String(), q.Name.String()) {
req.Answers[i].Header.Name.Data = q.Name.Data
}
}
}
req.RCode = dnsmessage.RCodeSuccess
req.Response = true
req.RecursionAvailable = true
req.Truncated = false
}
func (c *dnsCache) IncludeIp(ip netip.Addr) bool {
ip = common.ConvergeIp(ip)
for _, ans := range c.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
if !ip.Is4() {
continue
}
if netip.AddrFrom4(body.A) == ip {
return true
}
case *dnsmessage.AAAAResource:
if !ip.Is6() {
continue
}
if netip.AddrFrom16(body.AAAA) == ip {
return true
}
}
}
return false
}
// BatchUpdateDomainRouting update bpf map domain_routing. Since one IP may have multiple domains, this function should
// be invoked every A/AAAA-record lookup.
func (c *ControlPlane) BatchUpdateDomainRouting(cache *dnsCache) error {
// Parse ips from DNS resp answers.
var ips []netip.Addr
for _, ans := range cache.Answers {
switch ans.Header.Type {
case dnsmessage.TypeA:
ips = append(ips, netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A))
case dnsmessage.TypeAAAA:
ips = append(ips, netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA))
}
}
if len(ips) == 0 {
return nil
}
// Update bpf map.
// Construct keys and vals, and BpfMapBatchUpdate.
var keys [][4]uint32
var vals []bpfDomainRouting
for _, ip := range ips {
ip6 := ip.As16()
keys = append(keys, common.Ipv6ByteSliceToUint32Array(ip6[:]))
vals = append(vals, bpfDomainRouting{
Bitmap: [3]uint32{},
})
if len(cache.DomainBitmap) != len(vals[len(vals)-1].Bitmap) {
return fmt.Errorf("domain bitmap length not sync with kern program")
}
copy(vals[len(vals)-1].Bitmap[:], cache.DomainBitmap)
}
if _, err := BpfMapBatchUpdate(c.core.bpf.DomainRoutingMap, keys, vals, &ebpf.BatchOptions{
ElemFlags: uint64(ebpf.UpdateAny),
}); err != nil {
return err
}
return nil
}
func (c *ControlPlane) lookupDnsRespCache(domain string, t dnsmessage.Type) (cache *dnsCache) {
now := time.Now()
// To fqdn.
if !strings.HasSuffix(domain, ".") {
domain = domain + "."
}
c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[strings.ToLower(domain)+t.String()]
c.dnsCacheMu.Unlock()
if ok && cache.Deadline.After(now) {
return cache
}
return nil
}
func (c *ControlPlane) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) {
if len(msg.Questions) == 0 {
return nil
}
q := msg.Questions[0]
if msg.Response {
return nil
}
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return nil
}
cache := c.lookupDnsRespCache(q.Name.String(), q.Type)
if cache != nil {
cache.FillInto(msg)
b, err := msg.Pack()
if err != nil {
c.log.Warnf("failed to pack: %v", err)
return nil
}
if err = c.BatchUpdateDomainRouting(cache); err != nil {
c.log.Warnf("failed to BatchUpdateDomainRouting: %v", err)
return nil
}
return b
}
return nil
}
// FlipDnsQuestionCase is used to reduce dns pollution.
func FlipDnsQuestionCase(dm *dnsmessage.Message) {
if len(dm.Questions) == 0 {
return
}
q := &dm.Questions[0]
// For reproducibility, we use dm.ID as input and add some entropy to make the results more discrete.
h := fnv.New64()
var buf [4]byte
binary.BigEndian.PutUint16(buf[:], dm.ID)
h.Write(buf[:2])
binary.BigEndian.PutUint32(buf[:], 20230204) // entropy
h.Write(buf[:])
r := rand.New(rand.NewSource(int64(h.Sum64())))
perm := r.Perm(int(q.Name.Length))
for i := 0; i < int(q.Name.Length/3); i++ {
j := perm[i]
// Upper to lower; lower to upper.
if q.Name.Data[j] >= 'a' && q.Name.Data[j] <= 'z' {
q.Name.Data[j] -= 'a' - 'A'
} else if q.Name.Data[j] >= 'A' && q.Name.Data[j] <= 'Z' {
q.Name.Data[j] += 'a' - 'A'
}
}
}
// EnsureAdditionalOpt makes sure there is additional record OPT in the request.
func EnsureAdditionalOpt(dm *dnsmessage.Message, isReqAdd bool) (bool, error) {
// Check healthy resp.
if isReqAdd == dm.Response || dm.RCode != dnsmessage.RCodeSuccess || len(dm.Questions) == 0 {
return false, UnsupportedQuestionTypeError
}
q := dm.Questions[0]
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return false, UnsupportedQuestionTypeError
}
for _, ad := range dm.Additionals {
if ad.Header.Type == dnsmessage.TypeOPT {
// Already has additional record OPT.
return true, nil
}
}
if !isReqAdd {
return false, nil
}
// Add one.
dm.Additionals = append(dm.Additionals, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName("."),
Type: dnsmessage.TypeOPT,
Class: 512, TTL: 0, Length: 0,
},
Body: &dnsmessage.OPTResource{
Options: nil,
},
})
return false, nil
}
type RscWrapper struct {
Rsc dnsmessage.Resource
}
func (w RscWrapper) String() string {
return fmt.Sprintf("%v: %v", w.Rsc.Header.GoString(), w.Rsc.Body.GoString())
}
func FormatDnsRsc(ans []dnsmessage.Resource) (w []string) {
for _, a := range ans {
w = append(w, RscWrapper{Rsc: a}.String())
}
return w
}
// DnsRespHandler handle DNS resp. This function should be invoked when cache miss.
func (c *ControlPlane) DnsRespHandler(data []byte, validateRushAns bool) (newData []byte, err error) {
var msg dnsmessage.Message
if err = msg.Unpack(data); err != nil {
return nil, fmt.Errorf("unpack dns pkt: %w", err)
}
// Check healthy resp.
if !msg.Response || len(msg.Questions) == 0 {
return data, nil
}
FlipDnsQuestionCase(&msg)
q := msg.Questions[0]
// Align Name.
for i := range msg.Answers {
if strings.EqualFold(msg.Answers[i].Header.Name.String(), q.Name.String()) {
msg.Answers[i].Header.Name.Data = q.Name.Data
}
}
for i := range msg.Additionals {
if strings.EqualFold(msg.Additionals[i].Header.Name.String(), q.Name.String()) {
msg.Additionals[i].Header.Name.Data = q.Name.Data
}
}
for i := range msg.Authorities {
if strings.EqualFold(msg.Authorities[i].Header.Name.String(), q.Name.String()) {
msg.Authorities[i].Header.Name.Data = q.Name.Data
}
}
// Check suc resp.
if msg.RCode != dnsmessage.RCodeSuccess {
return msg.Pack()
}
// Check req type.
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return msg.Pack()
}
// Set ttl.
var ttl uint32
for i := range msg.Answers {
if ttl == 0 {
ttl = msg.Answers[i].Header.TTL
}
// Set TTL = zero. This requests applications must resend every request.
// However, it may be not defined in the standard.
msg.Answers[i].Header.TTL = 0
}
// Check if there is any A/AAAA record.
var hasIpRecord bool
loop:
for i := range msg.Answers {
switch msg.Answers[i].Header.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
hasIpRecord = true
break loop
}
}
if !hasIpRecord {
return msg.Pack()
}
if validateRushAns {
exist, e := EnsureAdditionalOpt(&msg, false)
if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) {
c.log.Warnf("EnsureAdditionalOpt: %v", e)
}
if e == nil && !exist {
// Additional record OPT in the request was ensured, and in normal case the resp should also set it.
// This DNS packet may be a rush-answer, and we should reject it.
c.log.WithFields(logrus.Fields{
"ques": q,
"addi": FormatDnsRsc(msg.Additionals),
"ans": FormatDnsRsc(msg.Answers),
}).Traceln("DNS rush-answer detected")
return nil, SuspectedRushAnswerError
}
}
// Update dnsCache.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"qname": q.Name,
"rcode": msg.RCode,
"ans": FormatDnsRsc(msg.Answers),
"auth": FormatDnsRsc(msg.Authorities),
"addi": FormatDnsRsc(msg.Additionals),
}).Tracef("Update DNS record cache")
}
if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil {
return nil, err
}
// Pack to get newData.
return msg.Pack()
}
func (c *ControlPlane) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) {
var fqdn string
if strings.HasSuffix(host, ".") {
fqdn = host
host = host[:len(host)-1]
} else {
fqdn = host + "."
}
// Bypass pure IP.
if _, err = netip.ParseAddr(host); err == nil {
return nil
}
cacheKey := fqdn + typ.String()
c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[cacheKey]
if ok {
c.dnsCacheMu.Unlock()
cache.Deadline = deadline
cache.Answers = answers
} else {
cache = &dnsCache{
DomainBitmap: c.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn),
Answers: answers,
Deadline: deadline,
}
c.dnsCache[cacheKey] = cache
c.dnsCacheMu.Unlock()
}
if err = c.BatchUpdateDomainRouting(cache); err != nil {
return fmt.Errorf("BatchUpdateDomainRouting: %w", err)
}
return nil
}

59
control/dns_cache.go Normal file
View File

@ -0,0 +1,59 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
*/
package control
import (
"github.com/mohae/deepcopy"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"strings"
"time"
)
type DnsCache struct {
DomainBitmap []uint32
Answers []dnsmessage.Resource
Deadline time.Time
}
func (c *DnsCache) FillInto(req *dnsmessage.Message) {
req.Answers = deepcopy.Copy(c.Answers).([]dnsmessage.Resource)
// Align question and answer Name.
if len(req.Questions) > 0 {
q := req.Questions[0]
for i := range req.Answers {
if strings.EqualFold(req.Answers[i].Header.Name.String(), q.Name.String()) {
req.Answers[i].Header.Name.Data = q.Name.Data
}
}
}
req.RCode = dnsmessage.RCodeSuccess
req.Response = true
req.RecursionAvailable = true
req.Truncated = false
}
func (c *DnsCache) IncludeIp(ip netip.Addr) bool {
for _, ans := range c.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
if !ip.Is4() {
continue
}
if netip.AddrFrom4(body.A) == ip {
return true
}
case *dnsmessage.AAAAResource:
if !ip.Is6() {
continue
}
if netip.AddrFrom16(body.AAAA) == ip {
return true
}
}
}
return false
}

557
control/dns_control.go Normal file
View File

@ -0,0 +1,557 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package control
import (
"encoding/binary"
"errors"
"fmt"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pool"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/common/netutils"
"github.com/v2rayA/dae/component/dns"
"github.com/v2rayA/dae/component/outbound"
"github.com/v2rayA/dae/component/outbound/dialer"
"golang.org/x/net/dns/dnsmessage"
"io"
"net"
"net/netip"
"strings"
"sync"
"time"
)
const (
MaxDnsLookupDepth = 3
)
var (
SuspectedRushAnswerError = fmt.Errorf("suspected DNS rush-answer")
UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type")
)
type DnsControllerOption struct {
Log *logrus.Logger
CacheAccessCallback func(cache *DnsCache) (err error)
NewCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
}
type DnsController struct {
routing *dns.Dns
log *logrus.Logger
cacheAccessCallback func(cache *DnsCache) (err error)
newCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
// mutex protects the dnsCache.
dnsCacheMu sync.Mutex
dnsCache map[string]*DnsCache
}
func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsController, err error) {
return &DnsController{
routing: routing,
log: option.Log,
cacheAccessCallback: option.CacheAccessCallback,
newCache: option.NewCache,
bestDialerChooser: option.BestDialerChooser,
dnsCacheMu: sync.Mutex{},
dnsCache: make(map[string]*DnsCache),
}, nil
}
func (c *DnsController) LookupDnsRespCache(domain string, t dnsmessage.Type) (cache *DnsCache) {
now := time.Now()
// To fqdn.
if !strings.HasSuffix(domain, ".") {
domain = domain + "."
}
c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[strings.ToLower(domain)+t.String()]
c.dnsCacheMu.Unlock()
if ok && cache.Deadline.After(now) {
return cache
}
return nil
}
// LookupDnsRespCache_ will modify the msg in place.
func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) {
if len(msg.Questions) == 0 {
return nil
}
q := msg.Questions[0]
if msg.Response {
return nil
}
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return nil
}
cache := c.LookupDnsRespCache(q.Name.String(), q.Type)
if cache != nil {
cache.FillInto(msg)
b, err := msg.Pack()
if err != nil {
c.log.Warnf("failed to pack: %v", err)
return nil
}
if err = c.cacheAccessCallback(cache); err != nil {
c.log.Warnf("failed to BatchUpdateDomainRouting: %v", err)
return nil
}
return b
}
return nil
}
// DnsRespHandler handle DNS resp.
func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMsg *dnsmessage.Message, err error) {
var msg dnsmessage.Message
if err = msg.Unpack(data); err != nil {
return nil, fmt.Errorf("unpack dns pkt: %w", err)
}
// Check healthy resp.
if !msg.Response || len(msg.Questions) == 0 {
return &msg, nil
}
q := msg.Questions[0]
// Check suc resp.
if msg.RCode != dnsmessage.RCodeSuccess {
return &msg, nil
}
// Check req type.
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return &msg, nil
}
// Set ttl.
var ttl uint32
for i := range msg.Answers {
if ttl == 0 {
ttl = msg.Answers[i].Header.TTL
}
// Set TTL = zero. This requests applications must resend every request.
// However, it may be not defined in the standard.
msg.Answers[i].Header.TTL = 0
}
// Check if there is any A/AAAA record.
var hasIpRecord bool
loop:
for i := range msg.Answers {
switch msg.Answers[i].Header.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
hasIpRecord = true
break loop
}
}
if !hasIpRecord {
return &msg, nil
}
if validateRushAns {
exist, e := EnsureAdditionalOpt(&msg, false)
if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) {
c.log.Warnf("EnsureAdditionalOpt: %v", e)
}
if e == nil && !exist {
// Additional record OPT in the request was ensured, and in normal case the resp should also set it.
// This DNS packet may be a rush-answer, and we should reject it.
c.log.WithFields(logrus.Fields{
"ques": q,
"addition": FormatDnsRsc(msg.Additionals),
"ans": FormatDnsRsc(msg.Answers),
}).Traceln("DNS rush-answer detected")
return nil, SuspectedRushAnswerError
}
}
// Update DnsCache.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"qname": q.Name,
"rcode": msg.RCode,
"ans": FormatDnsRsc(msg.Answers),
"auth": FormatDnsRsc(msg.Authorities),
"addition": FormatDnsRsc(msg.Additionals),
}).Tracef("Update DNS record cache")
}
if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil {
return nil, err
}
// Pack to get newData.
return &msg, nil
}
func (c *DnsController) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) {
var fqdn string
if strings.HasSuffix(host, ".") {
fqdn = host
host = host[:len(host)-1]
} else {
fqdn = host + "."
}
// Bypass pure IP.
if _, err = netip.ParseAddr(host); err == nil {
return nil
}
cacheKey := fqdn + typ.String()
c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[cacheKey]
if ok {
c.dnsCacheMu.Unlock()
cache.Deadline = deadline
cache.Answers = answers
} else {
cache, err = c.newCache(fqdn, answers, deadline)
if err != nil {
c.dnsCacheMu.Unlock()
return err
}
c.dnsCache[cacheKey] = cache
c.dnsCacheMu.Unlock()
}
if err = c.cacheAccessCallback(cache); err != nil {
return err
}
return nil
}
func (c *DnsController) DnsRespHandlerFactory(req *udpRequest, validateRushAnsFunc func(from netip.AddrPort) bool) func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
// Do not return conn-unrelated err in this func.
validateRushAns := validateRushAnsFunc(from)
msg, err = c.DnsRespHandler(data, validateRushAns)
if err != nil {
if errors.Is(err, SuspectedRushAnswerError) {
if validateRushAns {
// Reject DNS rush-answer.
c.log.WithFields(logrus.Fields{
"from": from,
}).Tracef("DNS rush-answer rejected")
return nil, nil
}
} else {
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.Debugf("DnsRespHandler: %v", err)
}
return nil, err
}
}
return msg, nil
}
}
type udpRequest struct {
lanWanFlag consts.LanWanFlag
realSrc netip.AddrPort
realDst netip.AddrPort
src netip.AddrPort
lConn *net.UDPConn
routingResult *bpfRoutingResult
}
type dialArgument struct {
l4proto consts.L4ProtoStr
ipversion consts.IpVersionStr
bestDialer *dialer.Dialer
bestOutbound *outbound.DialerGroup
bestTarget netip.AddrPort
mark uint32
}
func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) {
if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil {
// Send cache to client directly.
if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil {
return fmt.Errorf("failed to write cached DNS resp: %w", err)
}
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
c.log.Tracef("UDP(DNS) %v <-> Cache: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type,
)
}
return nil
}
// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
// Because rush-answer has no resp OPT. We can distinguish them from multiple responses.
// Note that additional record OPT may not be supported by home router either.
_, _ = EnsureAdditionalOpt(dnsMessage, true)
// Route request.
upstream, err := c.routing.RequestSelect(dnsMessage)
if err != nil {
return err
}
if c.log.IsLevelEnabled(logrus.TraceLevel) {
upstreamName := "asis"
if upstream != nil {
upstreamName = upstream.String()
}
c.log.WithFields(logrus.Fields{
"question": dnsMessage.Questions,
"upstream": upstreamName,
}).Traceln("Request to DNS upstream")
}
// Re-pack DNS packet.
data, err := dnsMessage.Pack()
if err != nil {
return fmt.Errorf("pack DNS packet: %w", err)
}
return c.dialSend(req, data, upstream, 0)
}
func (c *DnsController) dialSend(req *udpRequest, data []byte, upstream *dns.Upstream, invokingDepth int) (err error) {
if invokingDepth >= MaxDnsLookupDepth {
return fmt.Errorf("too deep DNS lookup invoking (depth: %v); there may be infinite loop in your DNS response routing", MaxDnsLookupDepth)
}
upstreamName := "asis"
if upstream == nil {
// As-is.
// As-is should not be valid in response routing, thus using connection realDest is reasonable.
var ip46 netutils.Ip46
if req.realDst.Addr().Is4() {
ip46.Ip4 = req.realDst.Addr()
} else {
ip46.Ip6 = req.realDst.Addr()
}
upstream = &dns.Upstream{
Scheme: "udp",
Hostname: req.realDst.Addr().String(),
Port: req.realDst.Port(),
Ip46: &ip46,
}
} else {
upstreamName = upstream.String()
}
// Select best dial arguments (outbound, dialer, l4proto, ipversion, etc.)
dialArgument, err := c.bestDialerChooser(req, upstream)
if err != nil {
return err
}
networkType := &dialer.NetworkType{
L4Proto: dialArgument.l4proto,
IpVersion: dialArgument.ipversion,
IsDns: true, // UDP relies on DNS check result.
}
// dnsRespHandler caches dns response and check rush answers.
dnsRespHandler := c.DnsRespHandlerFactory(req, func(from netip.AddrPort) bool {
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
// Because additional record OPT may not be supported by home router.
// So se should trust home devices even if they make rush-answer (or looks like).
return dialArgument.bestDialer.Name() == "direct" && !from.Addr().IsPrivate()
})
// Dial and send.
var respMsg *dnsmessage.Message
// defer in a recursive call will delay Close(), thus we Close() before
// the next recursive call. However, a connection cannot be closed twice.
// We should set a connClosed flag to avoid it.
var connClosed bool
var conn netproxy.Conn
// TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr).
// Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP.
// However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine.
switch dialArgument.l4proto {
case consts.L4ProtoStr_UDP:
// Get udp endpoint.
// TODO: connection pool.
conn, err = dialArgument.bestDialer.Dial(
MagicNetwork("udp", dialArgument.mark),
dialArgument.bestTarget.String(),
)
if err != nil {
return fmt.Errorf("failed to dial '%v': %w", dialArgument.bestTarget, err)
}
defer func() {
if !connClosed {
conn.Close()
}
}()
_ = conn.SetDeadline(time.Now().Add(DnsNatTimeout))
_, err = conn.Write(data)
if err != nil {
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.WithFields(logrus.Fields{
"to": dialArgument.bestTarget.String(),
"pid": req.routingResult.Pid,
"pname": ProcessName2String(req.routingResult.Pname[:]),
"mac": Mac2String(req.routingResult.Mac[:]),
"from": req.realSrc.String(),
"network": networkType.String(),
"err": err.Error(),
}).Debugln("Failed to write UDP(DNS) packet request.")
}
return fmt.Errorf("failed to write UDP(DNS) packet request: %w", err)
}
// We can block here because we are in a coroutine.
respBuf := pool.Get(512)
defer pool.Put(respBuf)
for {
// Wait for response.
n, err := conn.Read(respBuf)
if err != nil {
return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Name(), err)
}
respMsg, err = dnsRespHandler(respBuf[:n], dialArgument.bestTarget)
if err != nil {
return err
}
if respMsg != nil {
break
}
}
case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine.
conn, err = dialArgument.bestDialer.Dial(MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String())
if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
}
defer func() {
if !connClosed {
conn.Close()
}
}()
_ = conn.SetDeadline(time.Now().Add(DnsNatTimeout))
// We should write two byte length in the front of TCP DNS request.
bReq := pool.Get(2 + len(data))
defer pool.Put(bReq)
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
copy(bReq[2:], data)
_, err = conn.Write(bReq)
if err != nil {
return fmt.Errorf("failed to write DNS req: %w", err)
}
// Read two byte length.
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
}
respLen := int(binary.BigEndian.Uint16(bReq))
// Try to reuse the buf.
var buf []byte
if len(bReq) < respLen {
buf = pool.Get(respLen)
defer pool.Put(buf)
} else {
buf = bReq
}
var n int
if n, err = io.ReadFull(conn, buf[:respLen]); err != nil {
return fmt.Errorf("failed to read DNS resp payload: %w", err)
}
respMsg, err = dnsRespHandler(buf[:n], dialArgument.bestTarget)
if respMsg == nil && err == nil {
err = fmt.Errorf("bad DNS response")
}
if err != nil {
return fmt.Errorf("failed to write DNS resp to client: %w", err)
}
default:
return fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto)
}
// Close conn before the recursive call.
conn.Close()
connClosed = true
// Route response.
upstreamIndex, nextUpstream, err := c.routing.ResponseSelect(respMsg, upstream)
if err != nil {
return err
}
switch upstreamIndex {
case consts.DnsResponseOutboundIndex_Accept:
// Accept.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": respMsg.Questions,
"upstream": upstreamName,
}).Traceln("Accept")
}
case consts.DnsResponseOutboundIndex_Reject:
// Reject the request with empty answer.
respMsg.Answers = nil
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": respMsg.Questions,
"upstream": upstreamName,
}).Traceln("Reject with empty answer")
}
default:
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": respMsg.Questions,
"last_upstream": upstreamName,
"next_upstream": nextUpstream.String(),
}).Traceln("Change DNS upstream and resend")
}
return c.dialSend(req, data, nextUpstream, invokingDepth+1)
}
if upstreamIndex.IsReserved() && c.log.IsLevelEnabled(logrus.InfoLevel) {
var qname, qtype string
if len(respMsg.Questions) > 0 {
q := respMsg.Questions[0]
qname = strings.ToLower(q.Name.String())
qtype = q.Type.String()
}
fields := logrus.Fields{
"network": networkType.String(),
"outbound": dialArgument.bestOutbound.Name,
"policy": dialArgument.bestOutbound.GetSelectionPolicy(),
"dialer": dialArgument.bestDialer.Name(),
"qname": qname,
"qtype": qtype,
"pid": req.routingResult.Pid,
"pname": ProcessName2String(req.routingResult.Pname[:]),
"mac": Mac2String(req.routingResult.Mac[:]),
}
switch upstreamIndex {
case consts.DnsResponseOutboundIndex_Accept:
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), RefineAddrPortToShow(dialArgument.bestTarget))
case consts.DnsResponseOutboundIndex_Reject:
c.log.WithFields(fields).Infof("%v -> reject", RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag))
default:
return fmt.Errorf("unknown upstream: %v", upstreamIndex.String())
}
}
data, err = respMsg.Pack()
if err != nil {
return err
}
if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil {
return err
}
return nil
}

102
control/dns_utils.go Normal file
View File

@ -0,0 +1,102 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package control
import (
"encoding/binary"
"fmt"
"golang.org/x/net/dns/dnsmessage"
"hash/fnv"
"math/rand"
"net/netip"
"strings"
)
// FlipDnsQuestionCase is used to reduce dns pollution.
func FlipDnsQuestionCase(dm *dnsmessage.Message) {
if len(dm.Questions) == 0 {
return
}
q := &dm.Questions[0]
// For reproducibility, we use dm.ID as input and add some entropy to make the results more discrete.
h := fnv.New64()
var buf [4]byte
binary.BigEndian.PutUint16(buf[:], dm.ID)
h.Write(buf[:2])
binary.BigEndian.PutUint32(buf[:], 20230204) // entropy
h.Write(buf[:])
r := rand.New(rand.NewSource(int64(h.Sum64())))
perm := r.Perm(int(q.Name.Length))
for i := 0; i < int(q.Name.Length/3); i++ {
j := perm[i]
// Upper to lower; lower to upper.
if q.Name.Data[j] >= 'a' && q.Name.Data[j] <= 'z' {
q.Name.Data[j] -= 'a' - 'A'
} else if q.Name.Data[j] >= 'A' && q.Name.Data[j] <= 'Z' {
q.Name.Data[j] += 'a' - 'A'
}
}
}
// EnsureAdditionalOpt makes sure there is additional record OPT in the request.
func EnsureAdditionalOpt(dm *dnsmessage.Message, isReqAdd bool) (bool, error) {
// Check healthy resp.
if isReqAdd == dm.Response || dm.RCode != dnsmessage.RCodeSuccess || len(dm.Questions) == 0 {
return false, UnsupportedQuestionTypeError
}
q := dm.Questions[0]
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return false, UnsupportedQuestionTypeError
}
for _, ad := range dm.Additionals {
if ad.Header.Type == dnsmessage.TypeOPT {
// Already has additional record OPT.
return true, nil
}
}
if !isReqAdd {
return false, nil
}
// Add one.
dm.Additionals = append(dm.Additionals, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName("."),
Type: dnsmessage.TypeOPT,
Class: 512, TTL: 0, Length: 0,
},
Body: &dnsmessage.OPTResource{
Options: nil,
},
})
return false, nil
}
type RscWrapper struct {
Rsc dnsmessage.Resource
}
func (w RscWrapper) String() string {
var strBody string
switch body := w.Rsc.Body.(type) {
case *dnsmessage.AResource:
strBody = netip.AddrFrom4(body.A).String()
case *dnsmessage.AAAAResource:
strBody = netip.AddrFrom16(body.AAAA).String()
default:
strBody = body.GoString()
}
return fmt.Sprintf("%v(%v): %v", w.Rsc.Header.Name.String(), w.Rsc.Header.Type.String(), strBody)
}
func FormatDnsRsc(ans []dnsmessage.Resource) string {
var w []string
for _, a := range ans {
w = append(w, RscWrapper{Rsc: a}.String())
}
return strings.Join(w, "; ")
}

View File

@ -59,7 +59,7 @@
#define OUTBOUND_DIRECT 0
#define OUTBOUND_BLOCK 1
#define OUTBOUND_MUST_DIRECT 0xFC
#define OUTBOUND_CONTROL_PLANE_DIRECT 0xFD
#define OUTBOUND_CONTROL_PLANE_ROUTING 0xFD
#define OUTBOUND_LOGICAL_OR 0xFE
#define OUTBOUND_LOGICAL_AND 0xFF
#define OUTBOUND_LOGICAL_MASK 0xFE
@ -89,6 +89,7 @@ static const __u32 disable_l4_rx_checksum_key
__attribute__((unused, deprecated)) = 3;
static const __u32 control_plane_pid_key = 4;
static const __u32 control_plane_nat_direct_key = 5;
static const __u32 control_plane_dns_routing_key = 6;
// Outbound Connectivity Map:
@ -225,23 +226,6 @@ struct {
__uint(pinning, LIBBPF_PIN_BY_NAME);
} ipproto_hdrsize_map SEC(".maps");
// Dns upstream:
struct dns_upstream {
__be32 ip4[4];
__be32 ip6[4];
bool hasIp4;
bool hasIp6;
__be16 port;
};
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, struct dns_upstream);
/// FIXME: l4proto is always udp.
__uint(max_entries, 1);
} dns_upstream_map SEC(".maps");
// Interface Ips:
struct if_params {
bool rx_cksm_offload;
@ -946,7 +930,7 @@ decap_after_udp_hdr(struct __sk_buff *skb, __u8 ipversion, __u8 ihl,
// low -> high: outbound(8b) mark(32b) unused(23b) sign(1b)
static __s64 __attribute__((noinline))
routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
const __be32 _daddr[4], const __be32 mac[4]) {
const __be32 daddr[4], const __be32 mac[4]) {
#define _l4proto_type flag[0]
#define _ipversion_type flag[1]
#define _pname &flag[2]
@ -957,7 +941,6 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
__u32 key = MatchType_L4Proto;
__u16 h_dport;
__u16 h_sport;
__u32 daddr[4];
/// TODO: BPF_MAP_UPDATE_BATCH ?
if (unlikely((ret = bpf_map_update_elem(&l4proto_ipversion_map, &key,
@ -992,27 +975,11 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
// Modify DNS upstream for routing.
if (h_dport == 53 && _l4proto_type == L4ProtoType_UDP) {
struct dns_upstream *upstream =
bpf_map_lookup_elem(&dns_upstream_map, &zero_key);
if (upstream && upstream->port != 0) {
h_dport = bpf_ntohs(upstream->port);
if (_ipversion_type == IpVersionType_4 && upstream->hasIp4) {
__builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH);
} else if (_ipversion_type == IpVersionType_6 && upstream->hasIp6) {
__builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH);
} else if (upstream->hasIp4) {
__builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH);
} else if (upstream->hasIp6) {
__builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH);
} else {
bpf_printk("bad dns upstream; use as-is.");
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
__u32 *control_plane_dns_routing =
bpf_map_lookup_elem(&param_map, &control_plane_dns_routing_key);
if (control_plane_dns_routing && *control_plane_dns_routing) {
return OUTBOUND_CONTROL_PLANE_ROUTING;
}
} else {
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
}
} else {
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
}
lpm_key_instance.trie_key.prefixlen = IPV6_BYTE_LENGTH * 8;
__builtin_memcpy(lpm_key_instance.data, daddr, IPV6_BYTE_LENGTH);
@ -1169,11 +1136,6 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
bpf_printk("MATCHED: match_set->type: %u, match_set->not: %d",
match_set->type, match_set->not );
#endif
if (match_set->outbound == OUTBOUND_DIRECT && h_dport == 53 &&
_l4proto_type == L4ProtoType_UDP) {
// DNS packet should go through control plane.
return OUTBOUND_CONTROL_PLANE_DIRECT | (match_set->mark << 8);
}
return match_set->outbound | (match_set->mark << 8);
}
bad_rule = false;
@ -1583,7 +1545,6 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
bpf_skc_lookup_tcp(skb, &tuple, tuple_size, BPF_F_CURRENT_NETNS, 0);
if (sk) {
// Not a tproxy WAN response. It is a tproxy LAN response.
tproxy_response = false;
bpf_sk_release(sk);
return TC_ACT_OK;
}
@ -1594,6 +1555,9 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
// Packets from tproxy port.
// We need to redirect it to original port.
// bpf_printk("tproxy_response: %pI6:%u", tuples.dip.u6_addr32,
// bpf_ntohs(tuples.dport));
// Write mac.
if ((ret = bpf_skb_store_bytes(skb, offsetof(struct ethhdr, h_dest),
ethh.h_source, sizeof(ethh.h_source), 0))) {
@ -1665,7 +1629,7 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
struct dst_routing_result *dst =
bpf_map_lookup_elem(&tcp_dst_map, &key_src);
if (!dst) {
// Do not impact previous connections.
// Do not impact previous connections and server connections.
return TC_ACT_OK;
}
outbound = dst->routing_result.outbound;
@ -1978,7 +1942,7 @@ int tproxy_wan_ingress(struct __sk_buff *skb) {
return TC_ACT_SHOT;
}
// bpf_printk("real from: %pI4:%u", &ori_src.ip, bpf_ntohs(ori_src.port));
// bpf_printk("real from: %pI6:%u", ori_src.ip, bpf_ntohs(ori_src.port));
// Print packet in hex for debugging (checksum or something else).
// bpf_printk("UDP EGRESS OK");

View File

@ -9,79 +9,96 @@ import (
"encoding/binary"
"fmt"
"github.com/cilium/ebpf"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/v2rayA/dae/component/routing/domain_matcher"
"github.com/v2rayA/dae/config"
"github.com/v2rayA/dae/pkg/config_parser"
"net/netip"
"strconv"
)
type RoutingMatcherBuilder struct {
*routing.DefaultMatcherBuilder
outboundName2Id map[string]uint8
bpf *bpfObjects
rules []bpfMatchSet
simulatedLpmTries [][]netip.Prefix
simulatedDomainSet []routing.DomainSet
err error
fallback *routing.Outbound
}
func NewRoutingMatcherBuilder(outboundName2Id map[string]uint8, bpf *bpfObjects) *RoutingMatcherBuilder {
return &RoutingMatcherBuilder{outboundName2Id: outboundName2Id, bpf: bpf}
func NewRoutingMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, outboundName2Id map[string]uint8, bpf *bpfObjects, fallback config.FunctionOrString) (b *RoutingMatcherBuilder, err error) {
b = &RoutingMatcherBuilder{outboundName2Id: outboundName2Id, bpf: bpf}
rulesBuilder := routing.NewRulesBuilder(log)
rulesBuilder.RegisterFunctionParser(consts.Function_Domain, routing.PlainParserFactory(b.addDomain))
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
rulesBuilder.RegisterFunctionParser(consts.Function_SourceIp, routing.IpParserFactory(b.addSourceIp))
rulesBuilder.RegisterFunctionParser(consts.Function_Port, routing.PortRangeParserFactory(b.addPort))
rulesBuilder.RegisterFunctionParser(consts.Function_SourcePort, routing.PortRangeParserFactory(b.addSourcePort))
rulesBuilder.RegisterFunctionParser(consts.Function_L4Proto, routing.L4ProtoParserFactory(b.addL4Proto))
rulesBuilder.RegisterFunctionParser(consts.Function_Mac, routing.MacParserFactory(b.addSourceMac))
rulesBuilder.RegisterFunctionParser(consts.Function_ProcessName, routing.ProcessNameParserFactory(b.addProcessName))
rulesBuilder.RegisterFunctionParser(consts.Function_IpVersion, routing.IpVersionParserFactory(b.addIpVersion))
if err = rulesBuilder.Apply(rules); err != nil {
return nil, err
}
func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 {
if err = b.addFallback(fallback); err != nil {
return nil, err
}
return b, nil
}
func (b *RoutingMatcherBuilder) outboundToId(outbound string) (uint8, error) {
var outboundId uint8
switch outbound {
case routing.FakeOutbound_MUST_DIRECT:
case consts.OutboundMustDirect.String():
outboundId = uint8(consts.OutboundMustDirect)
case routing.FakeOutbound_AND:
outboundId = uint8(consts.OutboundLogicalAnd)
case routing.FakeOutbound_OR:
case consts.OutboundLogicalOr.String():
outboundId = uint8(consts.OutboundLogicalOr)
case consts.OutboundLogicalAnd.String():
outboundId = uint8(consts.OutboundLogicalAnd)
default:
var ok bool
outboundId, ok = b.outboundName2Id[outbound]
if !ok {
b.err = fmt.Errorf("outbound (group) %v not found; please define it in section \"group\"", strconv.Quote(outbound))
return 0, fmt.Errorf("outbound (group) %v not found; please define it in section \"group\"", strconv.Quote(outbound))
}
}
return outboundId
return outboundId, nil
}
func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *routing.Outbound) {
if b.err != nil {
return
}
func (b *RoutingMatcherBuilder) addDomain(f *config_parser.Function, key string, values []string, outbound *routing.Outbound) (err error) {
switch consts.RoutingDomainKey(key) {
case consts.RoutingDomainKey_Regex,
consts.RoutingDomainKey_Full,
consts.RoutingDomainKey_Keyword,
consts.RoutingDomainKey_Suffix:
default:
b.err = fmt.Errorf("AddDomain: unsupported key: %v", key)
return
return fmt.Errorf("addDomain: unsupported key: %v", key)
}
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
Key: consts.RoutingDomainKey(key),
RuleIndex: len(b.rules),
Domains: values,
})
outboundId, err := b.outboundToId(outbound.Name)
if err != nil {
return err
}
b.rules = append(b.rules, bpfMatchSet{
Type: uint8(consts.MatchType_DomainSet),
Not: f.Not,
Outbound: b.OutboundToId(outbound.Name),
Outbound: outboundId,
Mark: outbound.Mark,
})
return nil
}
func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound *routing.Outbound) {
if b.err != nil {
return
}
func (b *RoutingMatcherBuilder) addSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound *routing.Outbound) (err error) {
var addr16 [16]byte
values := make([]netip.Prefix, 0, len(macAddrs))
for _, mac := range macAddrs {
@ -91,41 +108,51 @@ func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs
}
lpmTrieIndex := len(b.simulatedLpmTries)
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
outboundId, err := b.outboundToId(outbound.Name)
if err != nil {
return err
}
set := bpfMatchSet{
Value: [16]byte{},
Type: uint8(consts.MatchType_Mac),
Not: f.Not,
Outbound: b.OutboundToId(outbound.Name),
Outbound: outboundId,
Mark: outbound.Mark,
}
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
b.rules = append(b.rules, set)
return nil
}
func (b *RoutingMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) {
if b.err != nil {
return
}
func (b *RoutingMatcherBuilder) addIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) (err error) {
lpmTrieIndex := len(b.simulatedLpmTries)
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
outboundId, err := b.outboundToId(outbound.Name)
if err != nil {
return err
}
set := bpfMatchSet{
Value: [16]byte{},
Type: uint8(consts.MatchType_IpSet),
Not: f.Not,
Outbound: b.OutboundToId(outbound.Name),
Outbound: outboundId,
Mark: outbound.Mark,
}
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
b.rules = append(b.rules, set)
return nil
}
func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) {
func (b *RoutingMatcherBuilder) addPort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) (err error) {
for i, value := range values {
outboundName := routing.FakeOutbound_OR
outboundName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
outboundName = outbound.Name
}
outboundId, err := b.outboundToId(outboundName)
if err != nil {
return err
}
b.rules = append(b.rules, bpfMatchSet{
Type: uint8(consts.MatchType_Port),
Value: _bpfPortRange{
@ -133,35 +160,42 @@ func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]u
PortEnd: value[1],
}.Encode(),
Not: f.Not,
Outbound: b.OutboundToId(outboundName),
Outbound: outboundId,
Mark: outbound.Mark,
})
}
return nil
}
func (b *RoutingMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) {
if b.err != nil {
return
}
func (b *RoutingMatcherBuilder) addSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) (err error) {
lpmTrieIndex := len(b.simulatedLpmTries)
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
outboundId, err := b.outboundToId(outbound.Name)
if err != nil {
return err
}
set := bpfMatchSet{
Value: [16]byte{},
Type: uint8(consts.MatchType_SourceIpSet),
Not: f.Not,
Outbound: b.OutboundToId(outbound.Name),
Outbound: outboundId,
Mark: outbound.Mark,
}
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
b.rules = append(b.rules, set)
return nil
}
func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) {
func (b *RoutingMatcherBuilder) addSourcePort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) (err error) {
for i, value := range values {
outboundName := routing.FakeOutbound_OR
outboundName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
outboundName = outbound.Name
}
outboundId, err := b.outboundToId(outboundName)
if err != nil {
return err
}
b.rules = append(b.rules, bpfMatchSet{
Type: uint8(consts.MatchType_SourcePort),
Value: _bpfPortRange{
@ -169,70 +203,83 @@ func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values
PortEnd: value[1],
}.Encode(),
Not: f.Not,
Outbound: b.OutboundToId(outboundName),
Outbound: outboundId,
Mark: outbound.Mark,
})
}
return nil
}
func (b *RoutingMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *routing.Outbound) {
if b.err != nil {
return
func (b *RoutingMatcherBuilder) addL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *routing.Outbound) (err error) {
outboundId, err := b.outboundToId(outbound.Name)
if err != nil {
return err
}
b.rules = append(b.rules, bpfMatchSet{
Value: [16]byte{byte(values)},
Type: uint8(consts.MatchType_L4Proto),
Not: f.Not,
Outbound: b.OutboundToId(outbound.Name),
Outbound: outboundId,
Mark: outbound.Mark,
})
return nil
}
func (b *RoutingMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *routing.Outbound) {
if b.err != nil {
return
func (b *RoutingMatcherBuilder) addIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *routing.Outbound) (err error) {
outboundId, err := b.outboundToId(outbound.Name)
if err != nil {
return err
}
b.rules = append(b.rules, bpfMatchSet{
Value: [16]byte{byte(values)},
Type: uint8(consts.MatchType_IpVersion),
Not: f.Not,
Outbound: b.OutboundToId(outbound.Name),
Outbound: outboundId,
Mark: outbound.Mark,
})
return nil
}
func (b *RoutingMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *routing.Outbound) {
func (b *RoutingMatcherBuilder) addProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *routing.Outbound) (err error) {
for i, value := range values {
outboundName := routing.FakeOutbound_OR
outboundName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
outboundName = outbound.Name
}
outboundId, err := b.outboundToId(outboundName)
if err != nil {
return err
}
matchSet := bpfMatchSet{
Type: uint8(consts.MatchType_ProcessName),
Not: f.Not,
Outbound: b.OutboundToId(outboundName),
Outbound: outboundId,
Mark: outbound.Mark,
}
copy(matchSet.Value[:], value[:])
b.rules = append(b.rules, matchSet)
}
return nil
}
func (b *RoutingMatcherBuilder) AddFallback(outbound *routing.Outbound) {
if b.err != nil {
return
func (b *RoutingMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
outbound, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
if err != nil {
return err
}
outboundId, err := b.outboundToId(outbound.Name)
if err != nil {
return err
}
b.rules = append(b.rules, bpfMatchSet{
Type: uint8(consts.MatchType_Fallback),
Outbound: b.OutboundToId(outbound.Name),
Outbound: outboundId,
Mark: outbound.Mark,
})
return nil
}
func (b *RoutingMatcherBuilder) BuildKernspace() (err error) {
if b.err != nil {
return b.err
}
// Update lpm_array_map.
for i, cidrs := range b.simulatedLpmTries {
var keys []_bpfLpmKey
@ -255,8 +302,7 @@ func (b *RoutingMatcherBuilder) BuildKernspace() (err error) {
// Write routings.
// Fallback rule MUST be the last.
if b.rules[len(b.rules)-1].Type != uint8(consts.MatchType_Fallback) {
b.err = fmt.Errorf("fallback rule MUST be the last")
return b.err
return fmt.Errorf("fallback rule MUST be the last")
}
routingsLen := uint32(len(b.rules))
routingsKeys := common.ARangeU32(routingsLen)
@ -266,34 +312,28 @@ func (b *RoutingMatcherBuilder) BuildKernspace() (err error) {
return fmt.Errorf("BpfMapBatchUpdate: %w", err)
}
// Release.
b.simulatedLpmTries = nil
return nil
}
func (b *RoutingMatcherBuilder) BuildUserspace() (matcher *RoutingMatcher, err error) {
if b.err != nil {
return nil, b.err
}
var m RoutingMatcher
func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *RoutingMatcher, err error) {
// Build domainMatcher
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
domainMatcher := domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet {
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
}
if err = m.domainMatcher.Build(); err != nil {
if err = domainMatcher.Build(); err != nil {
return nil, err
}
// Write routings.
// Fallback rule MUST be the last.
if b.rules[len(b.rules)-1].Type != uint8(consts.MatchType_Fallback) {
b.err = fmt.Errorf("fallback rule MUST be the last")
return nil, b.err
return nil, fmt.Errorf("fallback rule MUST be the last")
}
m.matches = b.rules
// Release.
b.simulatedDomainSet = nil
return &m, nil
return &RoutingMatcher{
lpmArrayMap: lpmArrayMap,
domainMatcher: domainMatcher,
matches: b.rules,
}, nil
}

View File

@ -8,13 +8,11 @@ package control
import (
"encoding/binary"
"fmt"
"github.com/Asphaltt/lpmtrie"
"github.com/cilium/ebpf"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"net"
"net/netip"
)
type RoutingMatcher struct {
@ -33,11 +31,11 @@ func (m *RoutingMatcher) Match(
ipVersion consts.IpVersionType,
l4proto consts.L4ProtoType,
domain string,
processName string,
processName [16]uint8,
mac []byte,
) (outboundIndex consts.OutboundIndex, err error) {
) (outboundIndex consts.OutboundIndex, mark uint32, err error) {
if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len {
return 0, fmt.Errorf("bad address length")
return 0, 0, fmt.Errorf("bad address length")
}
lpmKeys := make([]*_bpfLpmKey, consts.MatchType_Mac+1)
lpmKeys[consts.MatchType_IpSet] = &_bpfLpmKey{
@ -68,11 +66,13 @@ func (m *RoutingMatcher) Match(
lpmIndex := uint32(binary.LittleEndian.Uint16(match.Value[:]))
var lpm *ebpf.Map
if err = m.lpmArrayMap.Lookup(lpmIndex, &lpm); err != nil {
//logrus.Debugln("m.lpmArrayMap.Lookup:", err)
break
}
var v uint32
if err = lpm.Lookup(*lpmKeys[int(match.Type)], &v); err != nil {
_ = lpm.Close()
//logrus.Debugln("lpm.Lookup:", err, lpmKeys[int(match.Type)], match.Type, destAddr)
break
}
_ = lpm.Close()
@ -104,13 +104,13 @@ func (m *RoutingMatcher) Match(
goodSubrule = true
}
case consts.MatchType_ProcessName:
if processName != "" && string(match.Value[:]) == processName {
if processName[0] != 0 && match.Value == processName {
goodSubrule = true
}
case consts.MatchType_Fallback:
goodSubrule = true
default:
return 0, fmt.Errorf("unknown match type: %v", match.Type)
return 0, 0, fmt.Errorf("unknown match type: %v", match.Type)
}
beforeNextLoop:
outbound := consts.OutboundIndex(match.Outbound)
@ -133,27 +133,10 @@ func (m *RoutingMatcher) Match(
// Tail of a rule (line).
// Decide whether to hit.
if !badRule {
if outbound == consts.OutboundDirect && destPort == 53 &&
l4proto == consts.L4ProtoType_UDP {
// DNS packet should go through control plane.
return consts.OutboundControlPlaneDirect, nil
}
return outbound, nil
return outbound, match.Mark, nil
}
badRule = false
}
}
return 0, fmt.Errorf("no match set hit")
}
func cidrToLpmTrieKey(prefix netip.Prefix) lpmtrie.Key {
bits := prefix.Bits()
if prefix.Addr().Is4() {
bits += 96
}
ip := prefix.Addr().As16()
return lpmtrie.Key{
PrefixLen: bits,
Data: ip[:],
}
return 0, 0, fmt.Errorf("no match set hit")
}

View File

@ -59,20 +59,27 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
}
dst = netip.AddrPortFrom(dstAddr, common.Htons(value.Port))
}
src = common.ConvergeAddrPort(src)
dst = common.ConvergeAddrPort(dst)
var outboundIndex = consts.OutboundIndex(routingResult.Outbound)
switch outboundIndex {
case consts.OutboundDirect:
case consts.OutboundMustDirect:
fallthrough
case consts.OutboundControlPlaneDirect:
outboundIndex = consts.OutboundDirect
case consts.OutboundControlPlaneRouting:
if outboundIndex, routingResult.Mark, err = c.Route(src, dst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
return err
}
routingResult.Outbound = uint8(outboundIndex)
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.Tracef("outbound: %v => %v",
consts.OutboundControlPlaneRouting.String(),
outboundIndex.String(),
consts.OutboundDirect.String(),
)
}
outboundIndex = consts.OutboundDirect
default:
}
outbound := c.outbounds[outboundIndex]
@ -104,8 +111,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
}
// Dial and relay.
dst = netip.AddrPortFrom(common.ConvergeIp(dst.Addr()), dst.Port())
rConn, err := d.Dial(GetNetwork("tcp", routingResult.Mark), c.ChooseDialTarget(outboundIndex, dst, domain))
rConn, err := d.Dial(MagicNetwork("tcp", routingResult.Mark), c.ChooseDialTarget(outboundIndex, dst, domain))
if err != nil {
return fmt.Errorf("failed to dial %v: %w", dst, err)
}

View File

@ -7,21 +7,18 @@ package control
import (
"encoding/binary"
"encoding/gob"
"errors"
"fmt"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/mzz2017/softwind/pool"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/outbound/dialer"
"github.com/v2rayA/dae/component/sniffing"
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
"golang.org/x/net/dns/dnsmessage"
"io"
"net"
"net/netip"
"strings"
"syscall"
"time"
"unsafe"
@ -33,17 +30,14 @@ const (
MaxRetry = 2
)
var (
UnspecifiedAddr4 = netip.AddrFrom4([4]byte{})
UnspecifiedAddr6 = netip.AddrFrom16([16]byte{})
)
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Message, timeout time.Duration) {
if sniffDns {
var dnsmsg dnsmessage.Message
if err := dnsmsg.Unpack(data); err == nil {
//log.Printf("DEBUG: lookup %v", dnsmsg.Questions[0].Name)
return &dnsmsg, DnsNatTimeout
}
}
return nil, DefaultNatTimeout
}
@ -57,29 +51,34 @@ func ParseAddrHdr(data []byte) (hdr *bpfDstRoutingResult, dataOffset int, err er
return &_hdr, dataOffset, nil
}
func sendPktWithHdrWithFlag(data []byte, mark uint32, from netip.AddrPort, lConn *net.UDPConn, to netip.AddrPort, lanWanFlag consts.LanWanFlag) error {
func sendPktWithHdrWithFlag(data []byte, realFrom netip.AddrPort, lConn *net.UDPConn, to netip.AddrPort, lanWanFlag consts.LanWanFlag) error {
realFrom16 := realFrom.Addr().As16()
hdr := bpfDstRoutingResult{
Ip: common.Ipv6ByteSliceToUint32Array(from.Addr().AsSlice()),
Port: common.Htons(from.Port()),
Ip: common.Ipv6ByteSliceToUint32Array(realFrom16[:]),
Port: common.Htons(realFrom.Port()),
RoutingResult: bpfRoutingResult{
Outbound: uint8(lanWanFlag), // Pass some message to the kernel program.
},
}
buf := pool.Get(int(unsafe.Sizeof(hdr)) + len(data))
defer pool.Put(buf)
b := buffer.NewBufferFrom(buf)
// Do not put this 'buf' because it has been taken by buffer.
b := buffer.NewBuffer(int(unsafe.Sizeof(hdr)) + len(data))
defer b.Put()
if err := gob.NewEncoder(b).Encode(&hdr); err != nil {
// Use internal.NativeEndian due to already big endian.
if err := binary.Write(b, internal.NativeEndian, hdr); err != nil {
return err
}
copy(buf[int(unsafe.Sizeof(hdr)):], data)
//log.Println("from", from, "to", to)
_, err := lConn.WriteToUDPAddrPort(buf, to)
b.Write(data)
//logrus.Debugln("sendPktWithHdrWithFlag: from", realFrom, "to", to)
_, err := lConn.WriteToUDPAddrPort(b.Bytes(), to)
return err
}
// sendPkt uses bind first, and fallback to send hdr if addr is in use.
func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn, lanWanFlag consts.LanWanFlag) (err error) {
if lanWanFlag == consts.LanWanFlag_IsWan {
return sendPktWithHdrWithFlag(data, from, lConn, to, lanWanFlag)
}
d := net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
return dialer.BindControl(c, from)
}}
@ -88,7 +87,7 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn
if err != nil {
if errors.Is(err, syscall.EADDRINUSE) {
// Port collision, use traditional method.
return sendPktWithHdrWithFlag(data, 0, from, lConn, to, lanWanFlag)
return sendPktWithHdrWithFlag(data, from, lConn, to, lanWanFlag)
}
return err
}
@ -98,36 +97,6 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn
return err
}
func (c *ControlPlane) WriteToUDP(lanWanFlag consts.LanWanFlag, lConn *net.UDPConn, realTo, to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAnsFunc func(from netip.AddrPort) bool) UdpHandler {
return func(data []byte, from netip.AddrPort) (err error) {
// Do not return conn-unrelated err in this func.
if isDNS {
validateRushAns := validateRushAnsFunc(from)
data, err = c.DnsRespHandler(data, validateRushAns)
if err != nil {
if validateRushAns && errors.Is(err, SuspectedRushAnswerError) {
// Reject DNS rush-answer.
c.log.WithFields(logrus.Fields{
"from": from,
}).Tracef("DNS rush-answer rejected")
return err
}
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.Debugf("DnsRespHandler: %v", err)
}
if data == nil {
return nil
}
}
}
if dummyFrom != nil {
from = *dummyFrom
}
return sendPkt(data, from, realTo, to, lConn, lanWanFlag)
}
}
func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, realDst netip.AddrPort, routingResult *bpfRoutingResult) (err error) {
var lanWanFlag consts.LanWanFlag
var realSrc netip.AddrPort
@ -142,60 +111,11 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
realSrc = netip.AddrPortFrom(pktDst.Addr(), src.Port())
}
mustDirect := false
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
switch outboundIndex {
case consts.OutboundDirect:
case consts.OutboundMustDirect:
mustDirect = true
fallthrough
case consts.OutboundControlPlaneDirect:
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.Tracef("outbound: %v => %v",
outboundIndex.String(),
consts.OutboundDirect.String(),
)
}
outboundIndex = consts.OutboundDirect
default:
}
if int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
}
outbound := c.outbounds[outboundIndex]
// To keep consistency with kernel program, we only sniff DNS request sent to 53.
dnsMessage, natTimeout := ChooseNatTimeout(data, realDst.Port() == 53)
// We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time.
isDns := dnsMessage != nil
var dummyFrom *netip.AddrPort
destToSend := realDst
if isDns {
if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil {
// Send cache to client directly.
if err = sendPkt(resp, destToSend, realSrc, src, lConn, lanWanFlag); err != nil {
return fmt.Errorf("failed to write cached DNS resp: %w", err)
}
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
c.log.Tracef("UDP(DNS) %v <-[%v]-> Cache: %v %v",
RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), outbound.Name, strings.ToLower(q.Name.String()), q.Type,
)
}
return nil
}
// Flip dns question to reduce dns pollution.
FlipDnsQuestionCase(dnsMessage)
// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
// Because rush-answer has no resp OPT. We can distinguish them from multiple responses.
// Note that additional record OPT may not be supported by home router either.
_, _ = EnsureAdditionalOpt(dnsMessage, true)
// Re-pack DNS packet.
if data, err = dnsMessage.Pack(); err != nil {
return fmt.Errorf("pack flipped dns packet: %w", err)
}
} else {
if !isDns {
// Sniff Quic
sniffer := sniffing.NewPacketSniffer(data)
domain, err = sniffer.SniffQuic()
@ -206,108 +126,64 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
sniffer.Close()
}
l4proto := consts.L4ProtoStr_UDP
ipversion := consts.IpVersionFromAddr(realDst.Addr())
var dialerForNew *dialer.Dialer
// Get outbound.
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
switch outboundIndex {
case consts.OutboundDirect:
case consts.OutboundMustDirect:
outboundIndex = consts.OutboundDirect
isDns = false // Regard as plain traffic.
case consts.OutboundControlPlaneRouting:
if isDns {
// Routing of DNS packets are managed by DNS controller.
break
}
// For DNS request, modify realDst to dns upstream.
// NOTICE: We might modify l4proto and ipversion.
dnsUpstream, err := c.dnsUpstream.GetUpstream()
if err != nil {
if outboundIndex, routingResult.Mark, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
return err
}
if isDns && dnsUpstream != nil && !mustDirect {
// Modify dns target to upstream.
// NOTICE: Routing was calculated in advance by the eBPF program.
/// Choose the best l4proto+ipversion dialer, and change taregt DNS to the best ipversion DNS upstream for DNS request.
// Get available ipversions and l4protos for DNS upstream.
ipversions, l4protos := dnsUpstream.SupportedNetworks()
var (
bestDialer *dialer.Dialer
bestLatency time.Duration
bestTarget netip.AddrPort
routingResult.Outbound = uint8(outboundIndex)
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.Tracef("outbound: %v => %v",
consts.OutboundControlPlaneRouting.String(),
outboundIndex.String(),
)
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"ipversions": ipversions,
"l4protos": l4protos,
"src": realSrc.String(),
}).Traceln("Choose DNS path")
}
// Get the min latency path.
networkType := dialer.NetworkType{
IsDns: isDns,
default:
}
for _, ver := range ipversions {
for _, proto := range l4protos {
networkType.L4Proto = proto
networkType.IpVersion = ver
d, latency, err := outbound.Select(&networkType)
if err != nil {
continue
}
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"name": d.Name(),
"latency": latency,
"network": networkType.String(),
"outbound": outbound.Name,
}).Traceln("Choice")
}
if bestDialer == nil || latency < bestLatency {
bestDialer = d
bestLatency = latency
l4proto = proto
ipversion = ver
}
}
}
switch ipversion {
case consts.IpVersionStr_4:
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip4, dnsUpstream.Port)
case consts.IpVersionStr_6:
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip6, dnsUpstream.Port)
}
dialerForNew = bestDialer
dummyFrom = &realDst
destToSend = bestTarget
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"Original": RefineAddrPortToShow(realDst),
"New": destToSend,
"Network": string(l4proto) + string(ipversion),
}).Traceln("Modify DNS target")
if isDns {
return c.dnsController.Handle_(dnsMessage, &udpRequest{
lanWanFlag: lanWanFlag,
realSrc: realSrc,
realDst: realDst,
src: src,
lConn: lConn,
routingResult: routingResult,
})
}
if int(outboundIndex) >= len(c.outbounds) {
return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
}
outbound := c.outbounds[outboundIndex]
// Select dialer from outbound (dialer group).
networkType := &dialer.NetworkType{
L4Proto: l4proto,
IpVersion: ipversion,
IsDns: true,
L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionFromAddr(realDst.Addr()),
IsDns: true, // UDP relies on DNS check result.
}
if dialerForNew == nil {
dialerForNew, _, err = outbound.Select(networkType)
dialerForNew, _, err := outbound.Select(networkType)
if err != nil {
return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
}
}
var isNew bool
var realDialer *dialer.Dialer
udpHandler := c.WriteToUDP(lanWanFlag, lConn, realSrc, src, isDns, dummyFrom, func(from netip.AddrPort) bool {
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
// Because additional record OPT may not be supported by home router.
// So se should trust home devices even if they make rush-answer (or looks like).
return outboundIndex == consts.OutboundDirect && !common.ConvergeIp(from.Addr()).IsPrivate()
})
// Dial and send.
// TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr).
// Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP.
destToSend = netip.AddrPortFrom(common.ConvergeIp(destToSend.Addr()), destToSend.Port())
tgtToSend := c.ChooseDialTarget(outboundIndex, destToSend, domain)
switch l4proto {
case consts.L4ProtoStr_UDP:
// However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine.
dialTarget := c.ChooseDialTarget(outboundIndex, realDst, domain)
// Get udp endpoint.
var ue *UdpEndpoint
retry := 0
@ -315,13 +191,16 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
if retry > MaxRetry {
return fmt.Errorf("touch max retry limit")
}
ue, isNew, err = DefaultUdpEndpointPool.GetOrCreate(realSrc, &UdpEndpointOptions{
Handler: udpHandler,
ue, isNew, err := DefaultUdpEndpointPool.GetOrCreate(realSrc, &UdpEndpointOptions{
// Handler handles response packets and send it to the client.
Handler: func(data []byte, from netip.AddrPort) (err error) {
// Do not return conn-unrelated err in this func.
return sendPkt(data, from, realSrc, src, lConn, lanWanFlag)
},
NatTimeout: natTimeout,
Dialer: dialerForNew,
Network: GetNetwork("udp", routingResult.Mark),
Target: tgtToSend,
Network: MagicNetwork("udp", routingResult.Mark),
Target: dialTarget,
})
if err != nil {
return fmt.Errorf("failed to GetOrCreate (policy: %v): %w", outbound.GetSelectionPolicy(), err)
@ -342,20 +221,18 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
retry++
goto getNew
}
// This is real dialer.
realDialer = ue.Dialer
_, err = ue.WriteTo(data, tgtToSend)
_, err = ue.WriteTo(data, dialTarget)
if err != nil {
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.WithFields(logrus.Fields{
"to": destToSend.String(),
"to": realDst.String(),
"domain": domain,
"pid": routingResult.Pid,
"pname": ProcessName2String(routingResult.Pname[:]),
"mac": Mac2String(routingResult.Mac[:]),
"from": realSrc.String(),
"network": networkType.String(),
"network": networkType.StringWithoutDns(),
"err": err.Error(),
"retry": retry,
}).Debugln("Failed to write UDP packet request. Try to remove old UDP endpoint and retry.")
@ -364,89 +241,22 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
retry++
goto getNew
}
case consts.L4ProtoStr_TCP:
// MUST be DNS.
if !isDns {
return fmt.Errorf("UDP to TCP only support DNS request")
}
isNew = true
realDialer = dialerForNew
// We can block because we are in a coroutine.
conn, err := dialerForNew.Dial(GetNetwork("tcp", routingResult.Mark), tgtToSend)
if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
}
defer conn.Close()
_ = conn.SetDeadline(time.Now().Add(natTimeout))
// We should write two byte length in the front of TCP DNS request.
bReq := pool.Get(2 + len(data))
defer pool.Put(bReq)
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
copy(bReq[2:], data)
_, err = conn.Write(bReq)
if err != nil {
return fmt.Errorf("failed to write DNS req: %w", err)
}
// Read two byte length.
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
}
respLen := int(binary.BigEndian.Uint16(bReq))
// Try to reuse the buf.
var buf []byte
if len(bReq) < respLen {
buf = pool.Get(respLen)
defer pool.Put(buf)
} else {
buf = bReq
}
var n int
if n, err = io.ReadFull(conn, buf[:respLen]); err != nil {
return fmt.Errorf("failed to read DNS resp payload: %w", err)
}
if err = udpHandler(buf[:n], destToSend); err != nil {
return fmt.Errorf("failed to write DNS resp to client: %w", err)
}
}
// Print log.
if isNew || isDns {
// Only print routing for new connection to avoid the log exploded (Quic and BT).
if isDns && c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
c.log.WithFields(logrus.Fields{
"network": string(l4proto) + string(ipversion) + "(DNS)",
if isNew {
if c.log.IsLevelEnabled(logrus.InfoLevel) {
fields := logrus.Fields{
"network": networkType.StringWithoutDns(),
"outbound": outbound.Name,
"policy": outbound.GetSelectionPolicy(),
"dialer": realDialer.Name(),
"qname": strings.ToLower(q.Name.String()),
"qtype": q.Type,
"pid": routingResult.Pid,
"pname": ProcessName2String(routingResult.Pname[:]),
"mac": Mac2String(routingResult.Mac[:]),
}).Infof("%v <-> %v",
RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), RefineAddrPortToShow(destToSend),
)
} else if c.log.IsLevelEnabled(logrus.InfoLevel) {
if isDns && len(dnsMessage.Questions) > 0 {
domain = strings.ToLower(dnsMessage.Questions[0].Name.String())
}
c.log.WithFields(logrus.Fields{
"network": string(l4proto) + string(ipversion),
"outbound": outbound.Name,
"policy": outbound.GetSelectionPolicy(),
"dialer": realDialer.Name(),
"dialer": ue.Dialer.Name(),
"domain": domain,
"pid": routingResult.Pid,
"pname": ProcessName2String(routingResult.Pname[:]),
"mac": Mac2String(routingResult.Mac[:]),
}).Infof("%v <-> %v",
RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), RefineAddrPortToShow(destToSend),
)
}
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), RefineAddrPortToShow(realDst))
}
}

View File

@ -19,6 +19,32 @@ import (
"syscall"
)
func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto consts.L4ProtoType, routingResult *bpfRoutingResult) (outboundIndex consts.OutboundIndex, mark uint32, err error) {
var ipVersion consts.IpVersionType
if dst.Addr().Is4() || dst.Addr().Is4In6() {
ipVersion = consts.IpVersion_4
} else {
ipVersion = consts.IpVersion_6
}
bSrc := src.Addr().As16()
bDst := dst.Addr().As16()
if outboundIndex, mark, err = c.routingMatcher.Match(
bSrc[:],
bDst[:],
src.Port(),
dst.Port(),
ipVersion,
l4proto,
domain,
routingResult.Pname,
append([]uint8{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, routingResult.Mac[:]...),
); err != nil {
return 0, 0, err
}
return outboundIndex, mark, nil
}
func (c *ControlPlaneCore) RetrieveRoutingResult(src, dst netip.AddrPort, l4proto uint8) (result *bpfRoutingResult, err error) {
srcIp6 := src.Addr().As16()
dstIp6 := dst.Addr().As16()
@ -79,7 +105,7 @@ func CheckIpforward(ifname string) error {
return nil
}
func GetNetwork(network string, mark uint32) string {
func MagicNetwork(network string, mark uint32) string {
if mark == 0 {
return network
} else {

49
docs/dns.md Normal file
View File

@ -0,0 +1,49 @@
# DNS
## Examples:
```shell
dns {
upstream {
# Value can be scheme://host:port.
# Scheme list: tcp, udp, tcp+udp. Ongoing: https, tls, quic.
# If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose
# IPv4 or IPv6 to use according to group policy (such as min latency policy).
# Please make sure DNS traffic will go through and be forwarded by dae, which is REQUIRED for domain routing.
# If dial_mode is "ip", the upstream DNS answer SHOULD NOT be polluted, so domestic public DNS is not recommended.
alidns: 'udp://dns.alidns.com:53'
googledns: 'tcp+udp://dns.google:53'
}
# The routing format of 'request' and 'response' is similar with section 'routing'.
# See https://github.com/v2rayA/dae/blob/main/docs/routing.md
request {
# Built-in upstream in 'request': asis.
# You can also use user-defined upstreams.
# Available functions: qname, qtype.
# DNS request name (omit suffix dot '.').
qname(suffix: abc.com, keyword: google) -> googledns
qname(full: ok.com, regex: '^yes') -> googledns
# DNS request type
qtype(a, aaaa) -> alidns
qtype(cname) -> googledns
# If no match, fallback to this upstream.
fallback: asis
}
response {
# No built-in upstream in 'response'.
# You can use user-defined upstreams.
# Available functions: qname, qtype, upstream, ip.
# Accept the response if the request is sent to upstream 'googledns'. This is useful to avoid loop.
upstream(googledns) -> accept
# If DNS request name is not in CN and response answers include private IP, which is most likely polluted
# in China mainland. Therefore, resend DNS request to 'googledns' to get correct result.
!qname(geosite:cn) && ip(geoip:private) -> googledns
fallback: accept
}
}
```

View File

@ -1,4 +1,4 @@
# routing
# Routing
## Examples:

View File

@ -21,15 +21,6 @@ global {
# Group will switch node only when new_latency <= old_latency - tolerance.
check_tolerance: 50ms
# Value can be scheme://host:port or empty string ''.
# The scheme can be tcp/udp/tcp+udp. Empty string '' indicates as-is.
# If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose
# IPv4 or IPv6 to use according to group policy (such as min latency policy).
# Please make sure DNS traffic will go through and be forwarded by dae, which is REQUIRED for domain routing.
# The upstream DNS answer MUST NOT be polluted, so domestic public DNS is not recommended.
# The request to DNS upstream follows the routing defined below.
dns_upstream: 'udp://dns.alidns.com:53'
# The LAN interface to bind. Use it if you only want to proxy LAN instead of localhost.
# Multiple interfaces split by ",".
#lan_interface: docker0
@ -79,6 +70,28 @@ node {
'ss://LINK'
}
# See more at https://github.com/v2rayA/dae/blob/main/docs/dns.md.
dns {
upstream {
# Value can be scheme://host:port, where the scheme can be tcp/udp/tcp+udp.
# If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose
# IPv4 or IPv6 to use according to group policy (such as min latency policy).
# Please make sure DNS traffic will go through and be forwarded by dae, which is REQUIRED for domain routing.
# If dial_mode is "ip", the upstream DNS answer SHOULD NOT be polluted, so domestic public DNS is not recommended.
alidns: 'udp://dns.alidns.com:53'
googledns: 'tcp+udp://dns.google:53'
}
request {
fallback: asis
}
response {
upstream(googledns) -> accept
!qname(geosite:cn) && ip(geoip:private) -> googledns
fallback: accept
}
}
# Node group (outbound).
group {
my_group {
@ -108,7 +121,7 @@ group {
}
}
# See routing.md for full examples.
# See https://github.com/v2rayA/dae/blob/main/docs/routing.md for full examples.
routing {
### Preset rules.

3
go.mod
View File

@ -3,14 +3,13 @@ module github.com/v2rayA/dae
go 1.18
require (
github.com/Asphaltt/lpmtrie v0.0.0-20220205153150-3d814250b8ab
github.com/adrg/xdg v0.4.0
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20221202181307-76fa05c21b12
github.com/cilium/ebpf v0.10.0
github.com/gorilla/websocket v1.5.0
github.com/json-iterator/go v1.1.12
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f
github.com/mzz2017/softwind v0.0.0-20230224125402-d460ce1c5b4b
github.com/safchain/ethtool v0.0.0-20230116090318-67cc41908669
github.com/sirupsen/logrus v1.9.0
github.com/spf13/cobra v1.6.1

6
go.sum
View File

@ -1,5 +1,3 @@
github.com/Asphaltt/lpmtrie v0.0.0-20220205153150-3d814250b8ab h1:hzN25CB5VzeKk3/c1fi1oT03N+5365nVOMPAxixkADY=
github.com/Asphaltt/lpmtrie v0.0.0-20220205153150-3d814250b8ab/go.mod h1:TdNTLzn3VVXKfmHAULK5gY+h/A1gLQ8NnwLB6cSN54g=
github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls=
github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E=
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20221202181307-76fa05c21b12 h1:npHgfD4Tl2WJS3AJaMUi5ynGDPUBfkg3U3fCzDyXZ+4=
@ -68,8 +66,8 @@ github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B90M=
github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI=
github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f h1:Lmwy7FFI0PrWw0TgoQYtDiZBlCd/VZ1hBlySauTVWj4=
github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I=
github.com/mzz2017/softwind v0.0.0-20230224125402-d460ce1c5b4b h1:Do2nwPU6oKlZGBNUeTvyiNjFHRuOqAlunrQ+jwvSCJM=
github.com/mzz2017/softwind v0.0.0-20230224125402-d460ce1c5b4b/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=

View File

@ -15,6 +15,7 @@ import (
// license that can be found in the LICENSE file.
var smallBufferSize = 16
var defaultBufferSize = 64
// A Buffer is a variable-sized buffer of bytes with Read and Write methods.
// The zero value for Buffer is an empty buffer ready to use.
@ -158,7 +159,7 @@ func makeSlice[T constraints.Unsigned](n int) []T {
// sufficient to initialize a Buffer.
func NewBuffer[T constraints.Unsigned](size int) *Buffer[T] {
if size == 0 {
size = 512
size = defaultBufferSize
}
return &Buffer[T]{buf: make([]T, 0, size)}
}

View File

@ -181,7 +181,7 @@ type RoutingRule struct {
Outbound Function
}
func (r *RoutingRule) String(calcN bool) string {
func (r *RoutingRule) String(replaceParamWithN bool) string {
var builder strings.Builder
var n int
for i, f := range r.AndFunctions {
@ -190,7 +190,7 @@ func (r *RoutingRule) String(calcN bool) string {
}
var paramBuilder strings.Builder
n += len(f.Params)
if calcN {
if replaceParamWithN {
paramBuilder.WriteString("[n = " + strconv.Itoa(n) + "]")
} else {
for j, param := range f.Params {

View File

@ -3,18 +3,26 @@
package internal
import (
"github.com/v2rayA/dae/common"
"encoding/binary"
"syscall"
"unsafe"
)
// Htons converts the unsigned short integer hostshort from host byte order to network byte order.
func Htons(i uint16) uint16 {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, i)
return *(*uint16)(unsafe.Pointer(&b[0]))
}
func OpenRawSock(index int) (int, error) {
sock, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, int(common.Htons(syscall.ETH_P_ALL)))
sock, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, int(Htons(syscall.ETH_P_ALL)))
if err != nil {
return 0, err
}
sll := syscall.SockaddrLinklayer{
Ifindex: index,
Protocol: common.Htons(syscall.ETH_P_ALL),
Protocol: Htons(syscall.ETH_P_ALL),
}
if err := syscall.Bind(sock, &sll); err != nil {
return 0, err

View File

@ -1,5 +1,5 @@
// Package trie is modified from https://github.com/openacid/succinct/blob/loc100/sskv.go.
// Slower than about 50% but more memory saving.
// Slower than about 30% but more than 40% memory saving.
package trie
@ -9,53 +9,30 @@ import (
"math/bits"
)
var table = [256]byte{
97: 0, // 'a'
98: 1,
99: 2,
100: 3,
101: 4,
102: 5,
103: 6,
104: 7,
105: 8,
106: 9,
107: 10,
108: 11,
109: 12,
110: 13,
111: 14,
112: 15,
113: 16,
114: 17,
115: 18,
116: 19,
117: 20,
118: 21,
119: 22,
120: 23,
121: 24,
122: 25,
'-': 26,
'.': 27,
'^': 28,
'$': 29,
'1': 30,
'2': 31,
'3': 32,
'4': 33,
'5': 34,
'6': 35,
'7': 36,
'8': 37,
'9': 38,
'0': 39,
type ValidChars struct {
table [256]byte
n uint16
zeroChar byte
}
const N = 40
func NewValidChars(validChars []byte) (v *ValidChars) {
v = new(ValidChars)
for _, c := range validChars {
if v.n == 0 {
v.zeroChar = c
}
v.table[c] = byte(v.n)
v.n++
}
return v
}
func IsValidChar(b byte) bool {
return table[b] > 0 || b == 'a'
func (v *ValidChars) Size() int {
return int(v.n)
}
func (v *ValidChars) IsValidChar(c byte) bool {
return v.table[c] > 0 || c == v.zeroChar
}
// Trie is a succinct, sorted and static string set impl with compacted trie as
@ -103,22 +80,26 @@ type Trie struct {
ranks, selects []int32
labels *bitlist.CompactBitList
ranksBL, selectsBL *bitlist.CompactBitList
chars *ValidChars
}
// NewTrie creates a new *Trie struct, from a slice of sorted strings.
func NewTrie(keys []string) (*Trie, error) {
func NewTrie(keys []string, chars *ValidChars) (*Trie, error) {
// Check chars.
for _, key := range keys {
for _, c := range []byte(key) {
if !IsValidChar(c) {
if !chars.IsValidChar(c) {
return nil, fmt.Errorf("char out of range: %c", c)
}
}
}
ss := &Trie{}
ss.labels = bitlist.NewCompactBitList(bits.Len8(N))
ss := &Trie{
chars: chars,
labels: bitlist.NewCompactBitList(bits.Len(uint(chars.Size()))),
}
lIdx := 0
type qElt struct{ s, e, col int }
@ -142,7 +123,7 @@ func NewTrie(keys []string) (*Trie, error) {
}
queue = append(queue, qElt{frm, j, elt.col + 1})
ss.labels.Append(uint64(table[keys[frm][elt.col]]))
ss.labels.Append(uint64(chars.table[keys[frm][elt.col]]))
setBit(&ss.labelBitmap, lIdx, 0)
lIdx++
}
@ -190,13 +171,16 @@ func (ss *Trie) HasPrefix(word string) bool {
return true
}
c := word[i]
if !ss.chars.IsValidChar(c) {
return false
}
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
// no more labels in this node
return false
}
if byte(ss.labels.Get(bmIdx-nodeId)) == table[c] {
if byte(ss.labels.Get(bmIdx-nodeId)) == ss.chars.table[c] {
break
}
}

View File

@ -94,7 +94,7 @@ func TestTrie(t *testing.T) {
"zib.fmc^",
"zk.ytamlacbci.",
"zk.ytamlacbci^",
})
}, NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^")))
if err != nil {
t.Fatal(err)
}
@ -110,6 +110,9 @@ func TestTrie(t *testing.T) {
if !(trie.HasPrefix("nc.^") == true) {
t.Fatal("^.cn")
}
if !(trie.HasPrefix("nc._") == true) {
t.Fatal("_.cn")
}
if !(trie.HasPrefix("n") == false) {
t.Fatal("n")
}