mirror of
https://github.com/daeuniverse/dae.git
synced 2025-07-04 07:17:55 +07:00
feat: dns routing (#26)
This commit is contained in:
13
README.md
13
README.md
@ -4,18 +4,19 @@
|
||||
|
||||
**_dae_**, means goose, is a lightweight and high-performance transparent proxy solution.
|
||||
|
||||
In order to improve the traffic split performance as much as possible, dae runs the transparent proxy and traffic split suite in the linux kernel by eBPF. Therefore, we have the opportunity to make the direct traffic bypass the forwarding by proxy application and achieve true direct traffic through. Under such a magic trick, there is almost no performance loss and additional resource consumption for direct traffic.
|
||||
In order to improve the traffic split performance as much as possible, dae runs the transparent proxy and traffic split suite in the linux kernel by eBPF. Therefore, dae has the opportunity to make the direct traffic bypass the forwarding by proxy application and achieve true direct traffic through. Under such a magic trick, there is almost no performance loss and additional resource consumption for direct traffic.
|
||||
|
||||
As a successor of [v2rayA](https://github.com/v2rayA/v2rayA), dae abandoned v2ray-core to meet the needs of users more freely.
|
||||
|
||||
**Features**
|
||||
|
||||
1. Implement `Real direct` traffic split (need ipforward on) to achieve [high performance](https://docs.google.com/spreadsheets/d/1UaWU6nNho7edBNjNqC8dfGXLlW0-cm84MM7sH6Gp7UE/edit?usp=sharing).
|
||||
1. Implement `Real Direct` traffic split (need ipforward on) to achieve [high performance](https://docs.google.com/spreadsheets/d/1UaWU6nNho7edBNjNqC8dfGXLlW0-cm84MM7sH6Gp7UE/edit?usp=sharing).
|
||||
1. Support to split traffic by process name in local host.
|
||||
1. Support to split traffic by MAC address in LAN.
|
||||
1. Support to split traffic with invert match rules.
|
||||
1. Support to automatically switch nodes according to policy. That is to say, support to automatically test independent TCP/UDP/IPv4/IPv6 latencies, and then use the best nodes for corresponding traffic according to user-defined policy.
|
||||
1. Support full-cone NAT for shadowsocks, vmess, socks5 and trojan(-go).
|
||||
1. Support advanced DNS resolution process.
|
||||
1. Support full-cone NAT for shadowsocks, trojan(-go) and socks5 (no test).
|
||||
|
||||
## Prerequisites
|
||||
|
||||
@ -81,13 +82,9 @@ Please refer to [Quick Start Guide](./docs/getting-started/README.md) to start u
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Check dns upstream and source loop (whether upstream is also a client of us) and remind the user to add sip rule.
|
||||
- [ ] WAN L4Checksum problem.
|
||||
- [ ] If the NIC checksumming offload is enabled, the Linux network stack will make a simple checksum a packet when it is sent out from local. When NIC discovers that the source IP of the packet is the local IP of the NIC, it will checksum it complete this checksum.
|
||||
- [ ] But the problem is, after the Linux network stack, before entering the network card, we modify the source IP of this packet, causing the Linux network stack to only make a simple checksum, and the NIC also assumes that this packet is not sent from local, so no further checksum completing.
|
||||
- [ ] Automatically check dns upstream and source loop (whether upstream is also a client of us) and remind the user to add sip rule.
|
||||
- [ ] MACv2 extension extraction.
|
||||
- [ ] Log to userspace.
|
||||
- [ ] Protocol-oriented node features detecting (or filter), such as full-cone (especially VMess and VLESS).
|
||||
- [ ] DNS traffic split.
|
||||
- [ ] Add quick-start guide
|
||||
- [ ] ...
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
@ -84,11 +85,15 @@ func Run(log *logrus.Logger, param *config.Params) (err error) {
|
||||
param.Group,
|
||||
¶m.Routing,
|
||||
¶m.Global,
|
||||
¶m.Dns,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Call GC to release memory.
|
||||
runtime.GC()
|
||||
|
||||
// Serve tproxy TCP/UDP server util signals.
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGILL)
|
||||
|
@ -24,11 +24,11 @@ func NewCompactBitList(unitBitSize int) *CompactBitList {
|
||||
return &CompactBitList{
|
||||
unitBitSize: unitBitSize,
|
||||
size: 0,
|
||||
b: anybuffer.NewBuffer[uint16](1),
|
||||
b: anybuffer.NewBuffer[uint16](8),
|
||||
}
|
||||
}
|
||||
|
||||
// Set is not optimized yet.
|
||||
// Set function is not optimized yet.
|
||||
func (m *CompactBitList) Set(iUnit int, v uint64) {
|
||||
if bits.Len64(v) > m.unitBitSize {
|
||||
panic(fmt.Sprintf("value %v exceeds unit bit size", v))
|
||||
|
@ -41,6 +41,16 @@ func (l L4ProtoStr) ToL4Proto() uint8 {
|
||||
panic("unsupported l4proto")
|
||||
}
|
||||
|
||||
func (l L4ProtoStr) ToL4ProtoType() L4ProtoType {
|
||||
switch l {
|
||||
case L4ProtoStr_TCP:
|
||||
return L4ProtoType_TCP
|
||||
case L4ProtoStr_UDP:
|
||||
return L4ProtoType_UDP
|
||||
}
|
||||
panic("unsupported l4proto: " + l)
|
||||
}
|
||||
|
||||
type IpVersionStr string
|
||||
|
||||
const (
|
||||
@ -58,6 +68,16 @@ func (v IpVersionStr) ToIpVersion() uint8 {
|
||||
panic("unsupported ipversion")
|
||||
}
|
||||
|
||||
func (v IpVersionStr) ToIpVersionType() IpVersionType {
|
||||
switch v {
|
||||
case IpVersionStr_4:
|
||||
return IpVersion_4
|
||||
case IpVersionStr_6:
|
||||
return IpVersion_6
|
||||
}
|
||||
panic("unsupported ipversion")
|
||||
}
|
||||
|
||||
func IpVersionFromAddr(addr netip.Addr) IpVersionStr {
|
||||
var ipversion IpVersionStr
|
||||
switch {
|
||||
|
66
common/consts/dns.go
Normal file
66
common/consts/dns.go
Normal file
@ -0,0 +1,66 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package consts
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DnsRequestOutboundIndex uint8
|
||||
|
||||
const (
|
||||
DnsRequestOutboundIndex_AsIs DnsRequestOutboundIndex = 0xFD
|
||||
DnsRequestOutboundIndex_LogicalOr DnsRequestOutboundIndex = 0xFE
|
||||
DnsRequestOutboundIndex_LogicalAnd DnsRequestOutboundIndex = 0xFF
|
||||
DnsRequestOutboundIndex_LogicalMask DnsRequestOutboundIndex = 0xFE
|
||||
|
||||
DnsRequestOutboundIndex_UserDefinedMax = DnsRequestOutboundIndex_AsIs - 1
|
||||
)
|
||||
|
||||
func (i DnsRequestOutboundIndex) String() string {
|
||||
switch i {
|
||||
case DnsRequestOutboundIndex_AsIs:
|
||||
return "asis"
|
||||
case DnsRequestOutboundIndex_LogicalOr:
|
||||
return "<OR>"
|
||||
case DnsRequestOutboundIndex_LogicalAnd:
|
||||
return "<AND>"
|
||||
default:
|
||||
return "<index: " + strconv.Itoa(int(i)) + ">"
|
||||
}
|
||||
}
|
||||
|
||||
type DnsResponseOutboundIndex uint8
|
||||
|
||||
const (
|
||||
DnsResponseOutboundIndex_Accept DnsResponseOutboundIndex = 0xFC
|
||||
DnsResponseOutboundIndex_Reject DnsResponseOutboundIndex = 0xFD
|
||||
DnsResponseOutboundIndex_LogicalOr DnsResponseOutboundIndex = 0xFE
|
||||
DnsResponseOutboundIndex_LogicalAnd DnsResponseOutboundIndex = 0xFF
|
||||
DnsResponseOutboundIndex_LogicalMask DnsResponseOutboundIndex = 0xFE
|
||||
|
||||
DnsResponseOutboundIndex_UserDefinedMax = DnsResponseOutboundIndex_Accept - 1
|
||||
)
|
||||
|
||||
func (i DnsResponseOutboundIndex) String() string {
|
||||
switch i {
|
||||
case DnsResponseOutboundIndex_Accept:
|
||||
return "accept"
|
||||
case DnsResponseOutboundIndex_Reject:
|
||||
return "reject"
|
||||
case DnsResponseOutboundIndex_LogicalOr:
|
||||
return "<OR>"
|
||||
case DnsResponseOutboundIndex_LogicalAnd:
|
||||
return "<AND>"
|
||||
default:
|
||||
return "<index: " + strconv.Itoa(int(i)) + ">"
|
||||
}
|
||||
}
|
||||
|
||||
func (i DnsResponseOutboundIndex) IsReserved() bool {
|
||||
return !strings.HasPrefix(i.String(), "<index: ")
|
||||
}
|
@ -27,6 +27,7 @@ const (
|
||||
DisableL4RxChecksumKey
|
||||
ControlPlanePidKey
|
||||
ControlPlaneNatDirectKey
|
||||
ControlPlaneDnsRoutingKey
|
||||
|
||||
OneKey ParamKey = 1
|
||||
)
|
||||
@ -52,18 +53,22 @@ const (
|
||||
MatchType_Mac
|
||||
MatchType_ProcessName
|
||||
MatchType_Fallback
|
||||
|
||||
MatchType_Upstream
|
||||
MatchType_QType
|
||||
)
|
||||
|
||||
type OutboundIndex uint8
|
||||
|
||||
const (
|
||||
OutboundDirect OutboundIndex = 0
|
||||
OutboundBlock OutboundIndex = 1
|
||||
OutboundMustDirect OutboundIndex = 0xFC
|
||||
OutboundControlPlaneDirect OutboundIndex = 0xFD
|
||||
OutboundLogicalOr OutboundIndex = 0xFE
|
||||
OutboundLogicalAnd OutboundIndex = 0xFF
|
||||
OutboundLogicalMask OutboundIndex = 0xFE
|
||||
OutboundDirect OutboundIndex = iota
|
||||
OutboundBlock
|
||||
|
||||
OutboundMustDirect OutboundIndex = 0xFC
|
||||
OutboundControlPlaneRouting OutboundIndex = 0xFD
|
||||
OutboundLogicalOr OutboundIndex = 0xFE
|
||||
OutboundLogicalAnd OutboundIndex = 0xFF
|
||||
OutboundLogicalMask OutboundIndex = 0xFE
|
||||
|
||||
OutboundMax = OutboundLogicalAnd
|
||||
OutboundUserDefinedMax = OutboundMustDirect - 1
|
||||
@ -77,8 +82,8 @@ func (i OutboundIndex) String() string {
|
||||
return "block"
|
||||
case OutboundMustDirect:
|
||||
return "must_direct"
|
||||
case OutboundControlPlaneDirect:
|
||||
return "<Control Plane Direct>"
|
||||
case OutboundControlPlaneRouting:
|
||||
return "<Control Plane Routing>"
|
||||
case OutboundLogicalOr:
|
||||
return "<OR>"
|
||||
case OutboundLogicalAnd:
|
||||
|
@ -23,7 +23,9 @@ const (
|
||||
Function_Mac = "mac"
|
||||
Function_ProcessName = "pname"
|
||||
|
||||
Declaration_Fallback = "fallback"
|
||||
Function_QName = "qname"
|
||||
Function_QType = "qtype"
|
||||
Function_Upstream = "upstream"
|
||||
|
||||
OutboundParam_Mark = "mark"
|
||||
)
|
||||
|
@ -17,7 +17,7 @@ type Ip46 struct {
|
||||
Ip6 netip.Addr
|
||||
}
|
||||
|
||||
func ParseIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, tcp bool) (ipv46 *Ip46, err error) {
|
||||
func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, tcp bool) (ipv46 *Ip46, err error) {
|
||||
addrs4, err := ResolveNetip(ctx, dialer, dns, host, dnsmessage.TypeA, tcp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@ -48,7 +49,7 @@ func ARangeU32(n uint32) []uint32 {
|
||||
|
||||
func Ipv6ByteSliceToUint32Array(_ip []byte) (ip [4]uint32) {
|
||||
for j := 0; j < 16; j += 4 {
|
||||
ip[j/4] = binary.LittleEndian.Uint32(_ip[j : j+4])
|
||||
ip[j/4] = internal.NativeEndian.Uint32(_ip[j : j+4])
|
||||
}
|
||||
return ip
|
||||
}
|
||||
@ -61,7 +62,7 @@ func Ipv6ByteSliceToUint8Array(_ip []byte) (ip [16]uint8) {
|
||||
func Ipv6Uint32ArrayToByteSlice(_ip [4]uint32) (ip []byte) {
|
||||
ip = make([]byte, 16)
|
||||
for j := 0; j < 4; j++ {
|
||||
binary.LittleEndian.PutUint32(ip[j*4:], _ip[j])
|
||||
internal.NativeEndian.PutUint32(ip[j*4:], _ip[j])
|
||||
}
|
||||
return ip
|
||||
}
|
||||
@ -372,13 +373,20 @@ func BoolToString(b bool) string {
|
||||
}
|
||||
}
|
||||
|
||||
func ConvergeIp(addr netip.Addr) netip.Addr {
|
||||
func ConvergeAddr(addr netip.Addr) netip.Addr {
|
||||
if addr.Is4In6() {
|
||||
addr = netip.AddrFrom4(addr.As4())
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func ConvergeAddrPort(addrPort netip.AddrPort) netip.AddrPort {
|
||||
if addrPort.Addr().Is4In6() {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4(addrPort.Addr().As4()), addrPort.Port())
|
||||
}
|
||||
return addrPort
|
||||
}
|
||||
|
||||
func NewGcm(key []byte) (cipher.AEAD, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
|
203
component/dns/dns.go
Normal file
203
component/dns/dns.go
Normal file
@ -0,0 +1,203 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/config"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
|
||||
|
||||
type Dns struct {
|
||||
upstream []*UpstreamResolver
|
||||
upstream2IndexMu sync.Mutex
|
||||
upstream2Index map[*Upstream]int
|
||||
reqMatcher *RequestMatcher
|
||||
respMatcher *ResponseMatcher
|
||||
}
|
||||
|
||||
type NewOption struct {
|
||||
UpstreamReadyCallback func(raw *url.URL, upstream *Upstream) (err error)
|
||||
}
|
||||
|
||||
func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error) {
|
||||
s = &Dns{
|
||||
upstream2Index: map[*Upstream]int{
|
||||
nil: int(consts.DnsRequestOutboundIndex_AsIs),
|
||||
},
|
||||
}
|
||||
// Parse upstream.
|
||||
upstreamName2Id := map[string]uint8{}
|
||||
for i, upstreamRaw := range dns.Upstream {
|
||||
if i >= int(consts.DnsRequestOutboundIndex_UserDefinedMax) ||
|
||||
i >= int(consts.DnsResponseOutboundIndex_UserDefinedMax) {
|
||||
return nil, fmt.Errorf("too many upstreams")
|
||||
}
|
||||
|
||||
tag, link := common.GetTagFromLinkLikePlaintext(upstreamRaw)
|
||||
if tag == "" {
|
||||
return nil, fmt.Errorf("%w: '%v' has no tag", BadUpstreamFormatError, upstreamRaw)
|
||||
}
|
||||
u, err := url.Parse(link)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err)
|
||||
}
|
||||
r := &UpstreamResolver{
|
||||
Raw: u,
|
||||
FinishInitCallback: func(i int) func(raw *url.URL, upstream *Upstream) (err error) {
|
||||
return func(raw *url.URL, upstream *Upstream) (err error) {
|
||||
if opt != nil && opt.UpstreamReadyCallback != nil {
|
||||
if err = opt.UpstreamReadyCallback(raw, upstream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.upstream2IndexMu.Lock()
|
||||
s.upstream2Index[upstream] = i
|
||||
s.upstream2IndexMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
}(i),
|
||||
}
|
||||
upstreamName2Id[tag] = uint8(len(s.upstream))
|
||||
s.upstream = append(s.upstream, r)
|
||||
}
|
||||
// Optimize routings.
|
||||
if dns.Routing.Request.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Request.Rules,
|
||||
&routing.DatReaderOptimizer{Logger: log},
|
||||
&routing.MergeAndSortRulesOptimizer{},
|
||||
&routing.DeduplicateParamsOptimizer{},
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dns.Routing.Response.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Response.Rules,
|
||||
&routing.DatReaderOptimizer{Logger: log},
|
||||
&routing.MergeAndSortRulesOptimizer{},
|
||||
&routing.DeduplicateParamsOptimizer{},
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Parse request routing.
|
||||
reqMatcherBuilder, err := NewRequestMatcherBuilder(log, dns.Routing.Request.Rules, upstreamName2Id, dns.Routing.Request.Fallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
|
||||
}
|
||||
s.reqMatcher, err = reqMatcherBuilder.Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS request routing: %w", err)
|
||||
}
|
||||
// Parse response routing.
|
||||
respMatcherBuilder, err := NewResponseMatcherBuilder(log, dns.Routing.Response.Rules, upstreamName2Id, dns.Routing.Response.Fallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
|
||||
}
|
||||
s.respMatcher, err = respMatcherBuilder.Build()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build DNS response routing: %w", err)
|
||||
}
|
||||
if len(dns.Upstream) == 0 {
|
||||
// Immediately ready.
|
||||
if err = opt.UpstreamReadyCallback(nil, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Dns) RequestSelect(msg *dnsmessage.Message) (upstream *Upstream, err error) {
|
||||
if msg.Response {
|
||||
return nil, fmt.Errorf("DNS request expected but DNS response received")
|
||||
}
|
||||
|
||||
// Prepare routing.
|
||||
var qname string
|
||||
var qtype dnsmessage.Type
|
||||
if len(msg.Questions) == 0 {
|
||||
qname = ""
|
||||
qtype = 0
|
||||
} else {
|
||||
q := msg.Questions[0]
|
||||
qname = q.Name.String()
|
||||
qtype = q.Type
|
||||
}
|
||||
// Route.
|
||||
upstreamIndex, err := s.reqMatcher.Match(qname, qtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// nil indicates AsIs.
|
||||
if upstreamIndex == consts.DnsRequestOutboundIndex_AsIs {
|
||||
return nil, nil
|
||||
}
|
||||
if int(upstreamIndex) >= len(s.upstream) {
|
||||
return nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
|
||||
}
|
||||
// Get corresponding upstream.
|
||||
upstream, err = s.upstream[upstreamIndex].GetUpstream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return upstream, nil
|
||||
}
|
||||
|
||||
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
|
||||
if !msg.Response {
|
||||
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
|
||||
}
|
||||
|
||||
// Prepare routing.
|
||||
var qname string
|
||||
var qtype dnsmessage.Type
|
||||
var ips []netip.Addr
|
||||
if len(msg.Questions) == 0 {
|
||||
qname = ""
|
||||
qtype = 0
|
||||
} else {
|
||||
q := msg.Questions[0]
|
||||
qname = q.Name.String()
|
||||
qtype = q.Type
|
||||
for _, ans := range msg.Answers {
|
||||
switch body := ans.Body.(type) {
|
||||
case *dnsmessage.AResource:
|
||||
ips = append(ips, netip.AddrFrom4(body.A))
|
||||
case *dnsmessage.AAAAResource:
|
||||
ips = append(ips, netip.AddrFrom16(body.AAAA))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.upstream2IndexMu.Lock()
|
||||
from := s.upstream2Index[fromUpstream]
|
||||
s.upstream2IndexMu.Unlock()
|
||||
// Route.
|
||||
upstreamIndex, err = s.respMatcher.Match(qname, qtype, ips, consts.DnsRequestOutboundIndex(from))
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
// Get corresponding upstream if upstream is neither 'accept' nor 'reject'.
|
||||
if !upstreamIndex.IsReserved() {
|
||||
if int(upstreamIndex) >= len(s.upstream) {
|
||||
return 0, nil, fmt.Errorf("bad upstream index: %v not in [0, %v]", upstreamIndex, len(s.upstream)-1)
|
||||
}
|
||||
upstream, err = s.upstream[upstreamIndex].GetUpstream()
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
} else {
|
||||
// Assign explicitly to let coder know.
|
||||
upstream = nil
|
||||
}
|
||||
return upstreamIndex, upstream, nil
|
||||
}
|
47
component/dns/function_parser.go
Normal file
47
component/dns/function_parser.go
Normal file
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var typeNames = map[string]dnsmessage.Type{
|
||||
"A": dnsmessage.TypeA,
|
||||
"NS": dnsmessage.TypeNS,
|
||||
"CNAME": dnsmessage.TypeCNAME,
|
||||
"SOA": dnsmessage.TypeSOA,
|
||||
"PTR": dnsmessage.TypePTR,
|
||||
"MX": dnsmessage.TypeMX,
|
||||
"TXT": dnsmessage.TypeTXT,
|
||||
"AAAA": dnsmessage.TypeAAAA,
|
||||
"SRV": dnsmessage.TypeSRV,
|
||||
"OPT": dnsmessage.TypeOPT,
|
||||
"WKS": dnsmessage.TypeWKS,
|
||||
"HINFO": dnsmessage.TypeHINFO,
|
||||
"MINFO": dnsmessage.TypeMINFO,
|
||||
"AXFR": dnsmessage.TypeAXFR,
|
||||
"ALL": dnsmessage.TypeALL,
|
||||
}
|
||||
|
||||
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
|
||||
var types []dnsmessage.Type
|
||||
for _, v := range paramValueGroup {
|
||||
t, ok := typeNames[strings.ToUpper(v)]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown DNS request type: %v", v)
|
||||
}
|
||||
types = append(types, t)
|
||||
}
|
||||
return callback(f, types, overrideOutbound)
|
||||
}
|
||||
}
|
213
component/dns/request_routing.go
Normal file
213
component/dns/request_routing.go
Normal file
@ -0,0 +1,213 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/component/routing/domain_matcher"
|
||||
"github.com/v2rayA/dae/config"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type RequestMatcherBuilder struct {
|
||||
upstreamName2Id map[string]uint8
|
||||
simulatedDomainSet []routing.DomainSet
|
||||
fallback *routing.Outbound
|
||||
rules []requestMatchSet
|
||||
}
|
||||
|
||||
func NewRequestMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *RequestMatcherBuilder, err error) {
|
||||
b = &RequestMatcherBuilder{upstreamName2Id: upstreamName2Id}
|
||||
rulesBuilder := routing.NewRulesBuilder(log)
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
|
||||
if err = rulesBuilder.Apply(rules); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = b.addFallback(fallback); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsRequestOutboundIndex, err error) {
|
||||
switch upstream {
|
||||
case consts.DnsRequestOutboundIndex_AsIs.String():
|
||||
upstreamId = consts.DnsRequestOutboundIndex_AsIs
|
||||
case consts.DnsRequestOutboundIndex_LogicalAnd.String():
|
||||
upstreamId = consts.DnsRequestOutboundIndex_LogicalAnd
|
||||
case consts.DnsRequestOutboundIndex_LogicalOr.String():
|
||||
upstreamId = consts.DnsRequestOutboundIndex_LogicalOr
|
||||
default:
|
||||
_upstreamId, ok := b.upstreamName2Id[upstream]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("upstream %v not found; please define it in section \"dns.upstream\"", strconv.Quote(upstream))
|
||||
}
|
||||
upstreamId = consts.DnsRequestOutboundIndex(_upstreamId)
|
||||
}
|
||||
return upstreamId, nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
|
||||
switch consts.RoutingDomainKey(key) {
|
||||
case consts.RoutingDomainKey_Regex,
|
||||
consts.RoutingDomainKey_Full,
|
||||
consts.RoutingDomainKey_Keyword,
|
||||
consts.RoutingDomainKey_Suffix:
|
||||
default:
|
||||
return fmt.Errorf("addQName: unsupported key: %v", key)
|
||||
}
|
||||
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
||||
Key: consts.RoutingDomainKey(key),
|
||||
RuleIndex: len(b.simulatedDomainSet),
|
||||
Domains: values,
|
||||
})
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, requestMatchSet{
|
||||
Type: consts.MatchType_DomainSet,
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
upstreamName = upstream.Name
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstreamName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, requestMatchSet{
|
||||
Type: consts.MatchType_QType,
|
||||
Value: uint16(value),
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
|
||||
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, requestMatchSet{
|
||||
Type: consts.MatchType_Fallback,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) Build() (matcher *RequestMatcher, err error) {
|
||||
var m RequestMatcher
|
||||
// Build domainMatcher
|
||||
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
|
||||
for _, domains := range b.simulatedDomainSet {
|
||||
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
|
||||
}
|
||||
if err = m.domainMatcher.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write routings.
|
||||
// Fallback rule MUST be the last.
|
||||
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
|
||||
return nil, fmt.Errorf("fallback rule MUST be the last")
|
||||
}
|
||||
m.matches = b.rules
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type RequestMatcher struct {
|
||||
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
|
||||
|
||||
matches []requestMatchSet
|
||||
}
|
||||
|
||||
type requestMatchSet struct {
|
||||
Value uint16
|
||||
Not bool
|
||||
Type consts.MatchType
|
||||
Upstream uint8
|
||||
}
|
||||
|
||||
func (m *RequestMatcher) Match(
|
||||
qName string,
|
||||
qType dnsmessage.Type,
|
||||
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
|
||||
var domainMatchBitmap []uint32
|
||||
if qName != "" {
|
||||
domainMatchBitmap = m.domainMatcher.MatchDomainBitmap(qName)
|
||||
}
|
||||
|
||||
goodSubrule := false
|
||||
badRule := false
|
||||
for i, match := range m.matches {
|
||||
if badRule || goodSubrule {
|
||||
goto beforeNextLoop
|
||||
}
|
||||
switch match.Type {
|
||||
case consts.MatchType_DomainSet:
|
||||
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_QType:
|
||||
if qType == dnsmessage.Type(match.Value) {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Fallback:
|
||||
goodSubrule = true
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown match type: %v", match.Type)
|
||||
}
|
||||
beforeNextLoop:
|
||||
upstream := consts.DnsRequestOutboundIndex(match.Upstream)
|
||||
if upstream != consts.DnsRequestOutboundIndex_LogicalOr {
|
||||
// This match_set reaches the end of subrule.
|
||||
// We are now at end of rule, or next match_set belongs to another
|
||||
// subrule.
|
||||
|
||||
if goodSubrule == match.Not {
|
||||
// This subrule does not hit.
|
||||
badRule = true
|
||||
}
|
||||
|
||||
// Reset goodSubrule.
|
||||
goodSubrule = false
|
||||
}
|
||||
|
||||
if upstream&consts.DnsRequestOutboundIndex_LogicalMask !=
|
||||
consts.DnsRequestOutboundIndex_LogicalMask {
|
||||
// Tail of a rule (line).
|
||||
// Decide whether to hit.
|
||||
if !badRule {
|
||||
return upstream, nil
|
||||
}
|
||||
badRule = false
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("no match set hit")
|
||||
}
|
320
component/dns/response_routing.go
Normal file
320
component/dns/response_routing.go
Normal file
@ -0,0 +1,320 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/component/routing/domain_matcher"
|
||||
"github.com/v2rayA/dae/config"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"github.com/v2rayA/dae/pkg/trie"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'})
|
||||
|
||||
type ResponseMatcherBuilder struct {
|
||||
upstreamName2Id map[string]uint8
|
||||
simulatedDomainSet []routing.DomainSet
|
||||
ipSet []*trie.Trie
|
||||
fallback *routing.Outbound
|
||||
rules []responseMatchSet
|
||||
}
|
||||
|
||||
func NewResponseMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, upstreamName2Id map[string]uint8, fallback config.FunctionOrString) (b *ResponseMatcherBuilder, err error) {
|
||||
b = &ResponseMatcherBuilder{upstreamName2Id: upstreamName2Id}
|
||||
rulesBuilder := routing.NewRulesBuilder(log)
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QName, routing.PlainParserFactory(b.addQName))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_QType, TypeParserFactory(b.addQType))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Upstream, routing.EmptyKeyPlainParserFactory(b.addUpstream))
|
||||
if err = rulesBuilder.Apply(rules); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = b.addFallback(fallback); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) upstreamToId(upstream string) (upstreamId consts.DnsResponseOutboundIndex, err error) {
|
||||
switch upstream {
|
||||
case consts.DnsResponseOutboundIndex_Accept.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_Accept
|
||||
case consts.DnsResponseOutboundIndex_Reject.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_Reject
|
||||
case consts.DnsResponseOutboundIndex_LogicalAnd.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_LogicalAnd
|
||||
case consts.DnsResponseOutboundIndex_LogicalOr.String():
|
||||
upstreamId = consts.DnsResponseOutboundIndex_LogicalOr
|
||||
default:
|
||||
_upstreamId, ok := b.upstreamName2Id[upstream]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("upstream %v not found; please define it in \"dns.upstream\"", strconv.Quote(upstream))
|
||||
}
|
||||
upstreamId = consts.DnsResponseOutboundIndex(_upstreamId)
|
||||
}
|
||||
return upstreamId, nil
|
||||
}
|
||||
|
||||
func prefix2bin128(prefix netip.Prefix) (bin128 string) {
|
||||
bits := prefix.Bits()
|
||||
if prefix.Addr().Is4() {
|
||||
bits += 96
|
||||
}
|
||||
ip := prefix.Addr().As16()
|
||||
buf := buffer.NewBuffer(128)
|
||||
defer buf.Put()
|
||||
loop:
|
||||
for i := 0; i < len(ip); i++ {
|
||||
for j := 0; j < 8; j++ {
|
||||
if (ip[i]>>j)&1 == 1 {
|
||||
buf.WriteByte('1')
|
||||
} else {
|
||||
buf.WriteByte('0')
|
||||
}
|
||||
bits--
|
||||
if bits == 0 {
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.Prefix, upstream *routing.Outbound) (err error) {
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule := responseMatchSet{
|
||||
Value: uint16(len(b.ipSet)),
|
||||
Type: consts.MatchType_IpSet,
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
}
|
||||
var keys []string
|
||||
// Convert netip.Prefix -> '0' '1' string
|
||||
for _, prefix := range cidrs {
|
||||
keys = append(keys, prefix2bin128(prefix))
|
||||
}
|
||||
t, err := trie.NewTrie(keys, ValidCidrChars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.ipSet = append(b.ipSet, t)
|
||||
b.rules = append(b.rules, rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addQName(f *config_parser.Function, key string, values []string, upstream *routing.Outbound) (err error) {
|
||||
switch consts.RoutingDomainKey(key) {
|
||||
case consts.RoutingDomainKey_Regex,
|
||||
consts.RoutingDomainKey_Full,
|
||||
consts.RoutingDomainKey_Keyword,
|
||||
consts.RoutingDomainKey_Suffix:
|
||||
default:
|
||||
return fmt.Errorf("addQName: unsupported key: %v", key)
|
||||
}
|
||||
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
||||
Key: consts.RoutingDomainKey(key),
|
||||
RuleIndex: len(b.simulatedDomainSet),
|
||||
Domains: values,
|
||||
})
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_DomainSet,
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values []string, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
upstreamName = upstream.Name
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstreamName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lastUpstreamId, err := b.upstreamToId(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_Upstream,
|
||||
Value: uint16(lastUpstreamId),
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
upstreamName = upstream.Name
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstreamName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_QType,
|
||||
Value: uint16(value),
|
||||
Not: f.Not,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
|
||||
upstream, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upstreamId, err := b.upstreamToId(upstream.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, responseMatchSet{
|
||||
Type: consts.MatchType_Fallback,
|
||||
Upstream: uint8(upstreamId),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) Build() (matcher *ResponseMatcher, err error) {
|
||||
var m ResponseMatcher
|
||||
// Build domainMatcher.
|
||||
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
|
||||
for _, domains := range b.simulatedDomainSet {
|
||||
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
|
||||
}
|
||||
if err = m.domainMatcher.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// IpSet.
|
||||
m.ipSet = b.ipSet
|
||||
|
||||
// Write routings.
|
||||
// Fallback rule MUST be the last.
|
||||
if b.rules[len(b.rules)-1].Type != consts.MatchType_Fallback {
|
||||
return nil, fmt.Errorf("fallback rule MUST be the last")
|
||||
}
|
||||
m.matches = b.rules
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type ResponseMatcher struct {
|
||||
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
|
||||
ipSet []*trie.Trie
|
||||
|
||||
matches []responseMatchSet
|
||||
}
|
||||
|
||||
type responseMatchSet struct {
|
||||
Value uint16
|
||||
Not bool
|
||||
Type consts.MatchType
|
||||
Upstream uint8
|
||||
}
|
||||
|
||||
func (m *ResponseMatcher) Match(
|
||||
qName string,
|
||||
qType dnsmessage.Type,
|
||||
ips []netip.Addr,
|
||||
upstream consts.DnsRequestOutboundIndex,
|
||||
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
|
||||
if qName == "" {
|
||||
return 0, fmt.Errorf("qName cannot be empty")
|
||||
}
|
||||
qName = strings.TrimSuffix(strings.ToLower(qName), ".")
|
||||
domainMatchBitmap := m.domainMatcher.MatchDomainBitmap(qName)
|
||||
bin128 := make([]string, 0, len(ips))
|
||||
for _, ip := range ips {
|
||||
bin128 = append(bin128, prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128)))
|
||||
}
|
||||
|
||||
goodSubrule := false
|
||||
badRule := false
|
||||
for i, match := range m.matches {
|
||||
if badRule || goodSubrule {
|
||||
goto beforeNextLoop
|
||||
}
|
||||
switch match.Type {
|
||||
case consts.MatchType_DomainSet:
|
||||
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_IpSet:
|
||||
for _, bin128 := range bin128 {
|
||||
// Check if any of IP hit the rule.
|
||||
if m.ipSet[match.Value].HasPrefix(bin128) {
|
||||
goodSubrule = true
|
||||
break
|
||||
}
|
||||
}
|
||||
case consts.MatchType_QType:
|
||||
if qType == dnsmessage.Type(match.Value) {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Upstream:
|
||||
if upstream == consts.DnsRequestOutboundIndex(match.Value) {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Fallback:
|
||||
goodSubrule = true
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown match type: %v", match.Type)
|
||||
}
|
||||
beforeNextLoop:
|
||||
upstream := consts.DnsResponseOutboundIndex(match.Upstream)
|
||||
if upstream != consts.DnsResponseOutboundIndex_LogicalOr {
|
||||
// This match_set reaches the end of subrule.
|
||||
// We are now at end of rule, or next match_set belongs to another
|
||||
// subrule.
|
||||
|
||||
if goodSubrule == match.Not {
|
||||
// This subrule does not hit.
|
||||
badRule = true
|
||||
}
|
||||
|
||||
// Reset goodSubrule.
|
||||
goodSubrule = false
|
||||
}
|
||||
|
||||
if upstream&consts.DnsResponseOutboundIndex_LogicalMask !=
|
||||
consts.DnsResponseOutboundIndex_LogicalMask {
|
||||
// Tail of a rule (line).
|
||||
// Decide whether to hit.
|
||||
if !badRule {
|
||||
return upstream, nil
|
||||
}
|
||||
badRule = false
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("no match set hit")
|
||||
}
|
@ -3,68 +3,68 @@
|
||||
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package control
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/mzz2017/softwind/protocol/direct"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/common/netutils"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DnsUpstreamScheme string
|
||||
type UpstreamScheme string
|
||||
|
||||
const (
|
||||
DnsUpstreamScheme_TCP DnsUpstreamScheme = "tcp"
|
||||
DnsUpstreamScheme_UDP DnsUpstreamScheme = "udp"
|
||||
DnsUpstreamScheme_TCP_UDP DnsUpstreamScheme = "tcp+udp"
|
||||
UpstreamScheme_TCP UpstreamScheme = "tcp"
|
||||
UpstreamScheme_UDP UpstreamScheme = "udp"
|
||||
UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp"
|
||||
)
|
||||
|
||||
func (s DnsUpstreamScheme) ContainsTcp() bool {
|
||||
func (s UpstreamScheme) ContainsTcp() bool {
|
||||
switch s {
|
||||
case DnsUpstreamScheme_TCP,
|
||||
DnsUpstreamScheme_TCP_UDP:
|
||||
case UpstreamScheme_TCP,
|
||||
UpstreamScheme_TCP_UDP:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type DnsUpstream struct {
|
||||
Scheme DnsUpstreamScheme
|
||||
Hostname string
|
||||
Port uint16
|
||||
*netutils.Ip46
|
||||
}
|
||||
|
||||
func ParseDnsUpstream(dnsUpstream *url.URL) (scheme DnsUpstreamScheme, hostname string, port uint16, err error) {
|
||||
func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, err error) {
|
||||
var __port string
|
||||
switch scheme = DnsUpstreamScheme(dnsUpstream.Scheme); scheme {
|
||||
case DnsUpstreamScheme_TCP, DnsUpstreamScheme_UDP, DnsUpstreamScheme_TCP_UDP:
|
||||
__port = dnsUpstream.Port()
|
||||
switch scheme = UpstreamScheme(raw.Scheme); scheme {
|
||||
case UpstreamScheme_TCP, UpstreamScheme_UDP, UpstreamScheme_TCP_UDP:
|
||||
__port = raw.Port()
|
||||
if __port == "" {
|
||||
__port = "53"
|
||||
}
|
||||
default:
|
||||
return "", "", 0, fmt.Errorf("unexpected dns_upstream format")
|
||||
}
|
||||
_port, err := strconv.ParseUint(dnsUpstream.Port(), 10, 16)
|
||||
port = uint16(_port)
|
||||
_port, err := strconv.ParseUint(__port, 10, 16)
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("parse dns_upstream port: %v", err)
|
||||
return "", "", 0, fmt.Errorf("failed to parse dns_upstream port: %v", err)
|
||||
}
|
||||
hostname = dnsUpstream.Hostname()
|
||||
port = uint16(_port)
|
||||
hostname = raw.Hostname()
|
||||
return scheme, hostname, port, nil
|
||||
}
|
||||
|
||||
func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstream, err error) {
|
||||
scheme, hostname, port, err := ParseDnsUpstream(dnsUpstream)
|
||||
type Upstream struct {
|
||||
Scheme UpstreamScheme
|
||||
Hostname string
|
||||
Port uint16
|
||||
*netutils.Ip46
|
||||
}
|
||||
|
||||
func NewUpstream(ctx context.Context, upstream *url.URL) (up *Upstream, err error) {
|
||||
scheme, hostname, port, err := ParseRawUpstream(upstream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -79,7 +79,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
|
||||
}
|
||||
}()
|
||||
|
||||
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false)
|
||||
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err)
|
||||
}
|
||||
@ -87,7 +87,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
|
||||
return nil, fmt.Errorf("dns_upstream has no record")
|
||||
}
|
||||
|
||||
return &DnsUpstream{
|
||||
return &Upstream{
|
||||
Scheme: scheme,
|
||||
Hostname: hostname,
|
||||
Port: port,
|
||||
@ -95,7 +95,7 @@ func ResolveDnsUpstream(ctx context.Context, dnsUpstream *url.URL) (up *DnsUpstr
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *DnsUpstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
|
||||
func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4protos []consts.L4ProtoStr) {
|
||||
if u.Ip4.IsValid() && u.Ip6.IsValid() {
|
||||
ipversions = []consts.IpVersionStr{consts.IpVersionStr_4, consts.IpVersionStr_6}
|
||||
} else {
|
||||
@ -106,27 +106,31 @@ func (u *DnsUpstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4p
|
||||
}
|
||||
}
|
||||
switch u.Scheme {
|
||||
case DnsUpstreamScheme_TCP:
|
||||
case UpstreamScheme_TCP:
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP}
|
||||
case DnsUpstreamScheme_UDP:
|
||||
case UpstreamScheme_UDP:
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP}
|
||||
case DnsUpstreamScheme_TCP_UDP:
|
||||
case UpstreamScheme_TCP_UDP:
|
||||
// UDP first.
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP, consts.L4ProtoStr_TCP}
|
||||
}
|
||||
return ipversions, l4protos
|
||||
}
|
||||
|
||||
type DnsUpstreamRaw struct {
|
||||
Raw common.UrlOrEmpty
|
||||
func (u *Upstream) String() string {
|
||||
return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port)))
|
||||
}
|
||||
|
||||
type UpstreamResolver struct {
|
||||
Raw *url.URL
|
||||
// FinishInitCallback may be invoked again if err is not nil
|
||||
FinishInitCallback func(raw common.UrlOrEmpty, upstream *DnsUpstream) (err error)
|
||||
FinishInitCallback func(raw *url.URL, upstream *Upstream) (err error)
|
||||
mu sync.Mutex
|
||||
upstream *DnsUpstream
|
||||
upstream *Upstream
|
||||
init bool
|
||||
}
|
||||
|
||||
func (u *DnsUpstreamRaw) GetUpstream() (_ *DnsUpstream, err error) {
|
||||
func (u *UpstreamResolver) GetUpstream() (_ *Upstream, err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if !u.init {
|
||||
@ -141,12 +145,8 @@ func (u *DnsUpstreamRaw) GetUpstream() (_ *DnsUpstream, err error) {
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), 10*time.Second)
|
||||
defer cancel()
|
||||
if !u.Raw.Empty {
|
||||
if u.upstream, err = ResolveDnsUpstream(ctx, u.Raw.Url); err != nil {
|
||||
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
|
||||
}
|
||||
} else {
|
||||
// Empty string. As-is.
|
||||
if u.upstream, err = NewUpstream(ctx, u.Raw); err != nil {
|
||||
return nil, fmt.Errorf("failed to init dns upstream: %v", err)
|
||||
}
|
||||
}
|
||||
return u.upstream, nil
|
@ -90,11 +90,8 @@ func (a *AliveDialerSet) GetMinLatency() (d *Dialer, latency time.Duration) {
|
||||
}
|
||||
|
||||
func (a *AliveDialerSet) printLatencies() {
|
||||
if !a.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
return
|
||||
}
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("%v (%v):\n", a.dialerGroupName, a.CheckTyp.String()))
|
||||
builder.WriteString(fmt.Sprintf("Group '%v' [%v]:\n", a.dialerGroupName, a.CheckTyp.String()))
|
||||
for _, d := range a.inorderedAliveDialerSet {
|
||||
latency, ok := a.dialerToLatency[d]
|
||||
if !ok {
|
||||
@ -210,9 +207,13 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
|
||||
string(a.selectionPolicy): a.minLatency.latency,
|
||||
"group": a.dialerGroupName,
|
||||
"network": a.CheckTyp.String(),
|
||||
"new dialer": a.minLatency.dialer.Name(),
|
||||
"old dialer": oldDialerName,
|
||||
"new_dialer": a.minLatency.dialer.Name(),
|
||||
"old_dialer": oldDialerName,
|
||||
}).Infof("Group %vselects dialer", re)
|
||||
|
||||
if a.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
a.printLatencies()
|
||||
}
|
||||
} else {
|
||||
// Alive -> not alive
|
||||
defer a.aliveChangeCallback(false)
|
||||
@ -221,9 +222,6 @@ func (a *AliveDialerSet) NotifyLatencyChange(dialer *Dialer, alive bool) {
|
||||
"network": a.CheckTyp.String(),
|
||||
}).Infof("Group has no dialer alive")
|
||||
}
|
||||
if a.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
a.printLatencies()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if alive && minPolicy && a.minLatency.dialer == nil {
|
||||
|
@ -118,7 +118,7 @@ func ParseTcpCheckOption(ctx context.Context, rawURL string) (opt *TcpCheckOptio
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
|
||||
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -153,7 +153,7 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort string) (opt *CheckDns
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("bad port: %v", err)
|
||||
}
|
||||
ip46, err := netutils.ParseIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
|
||||
ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -409,6 +409,10 @@ func (d *Dialer) NotifyCheck() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) MustGetLatencies10(typ *NetworkType) *LatenciesN {
|
||||
return d.mustGetCollection(typ).Latencies10
|
||||
}
|
||||
|
||||
// RegisterAliveDialerSet is thread-safe.
|
||||
func (d *Dialer) RegisterAliveDialerSet(a *AliveDialerSet) {
|
||||
if a == nil {
|
||||
|
@ -19,7 +19,7 @@ type DialerSelectionPolicy struct {
|
||||
FixedIndex int
|
||||
}
|
||||
|
||||
func NewDialerSelectionPolicyFromGroupParam(param *config.GroupParam) (policy *DialerSelectionPolicy, err error) {
|
||||
func NewDialerSelectionPolicyFromGroupParam(param *config.Group) (policy *DialerSelectionPolicy, err error) {
|
||||
switch val := param.Policy.(type) {
|
||||
case string:
|
||||
switch consts.DialerSelectionPolicy(val) {
|
||||
|
@ -15,6 +15,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ValidDomainChars = trie.NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^"))
|
||||
|
||||
type AhocorasickSlimtrie struct {
|
||||
validAcIndexes []int
|
||||
validTrieIndexes []int
|
||||
@ -173,7 +175,7 @@ func (n *AhocorasickSlimtrie) Build() (err error) {
|
||||
}
|
||||
toBuild = ToSuffixTrieStrings(toBuild)
|
||||
sort.Strings(toBuild)
|
||||
n.trie[i], err = trie.NewTrie(toBuild)
|
||||
n.trie[i], err = trie.NewTrie(toBuild, ValidDomainChars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -94,7 +94,6 @@ var TestSample = []string{
|
||||
}
|
||||
|
||||
type RoutingMatcherBuilder struct {
|
||||
*routing.DefaultMatcherBuilder
|
||||
outboundName2Id map[string]uint8
|
||||
simulatedDomainSet []routing.DomainSet
|
||||
Fallback string
|
||||
@ -118,7 +117,7 @@ func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string,
|
||||
consts.RoutingDomainKey_Keyword,
|
||||
consts.RoutingDomainKey_Suffix:
|
||||
default:
|
||||
b.err = fmt.Errorf("AddDomain: unsupported key: %v", key)
|
||||
b.err = fmt.Errorf("addDomain: unsupported key: %v", key)
|
||||
return
|
||||
}
|
||||
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
||||
@ -132,22 +131,16 @@ func getDomain() (simulatedDomainSet []routing.DomainSet, err error) {
|
||||
var rules []*config_parser.RoutingRule
|
||||
sections, err := config_parser.Parse(`
|
||||
routing {
|
||||
pname(NetworkManager, dnsmasq, systemd-resolved) -> must_direct # Traffic of DNS in local must be direct to avoid loop when binding to WAN.
|
||||
pname(sogou-qimpanel, sogou-qimpanel-watchdog) -> block
|
||||
ip(geoip:private, 224.0.0.0/3, 'ff00::/8') -> direct # Put it in front unless you know what you're doing.
|
||||
domain(geosite:bing)->us
|
||||
domain(full:dns.google) && port(53) -> direct
|
||||
domain(full:dns.google) -> direct
|
||||
domain(geosite:category-ads-all) -> block
|
||||
ip(geoip:private) -> direct
|
||||
ip(geoip:cn) -> direct
|
||||
domain(geosite:cn) -> direct
|
||||
fallback: my_group
|
||||
}`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var r config.Routing
|
||||
if err = config.RoutingRuleAndParamParser(reflect.ValueOf(&r), sections[0]); err != nil {
|
||||
if err = config.SectionParser(reflect.ValueOf(&r), sections[0]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rules, err = routing.ApplyRulesOptimizers(r.Rules,
|
||||
@ -159,8 +152,13 @@ routing {
|
||||
return nil, fmt.Errorf("ApplyRulesOptimizers error:\n%w", err)
|
||||
}
|
||||
builder := RoutingMatcherBuilder{}
|
||||
if err = routing.ApplyMatcherBuilder(logrus.StandardLogger(), &builder, rules, r.Fallback); err != nil {
|
||||
return nil, fmt.Errorf("ApplyMatcherBuilder: %w", err)
|
||||
rb := routing.NewRulesBuilder(logrus.StandardLogger())
|
||||
rb.RegisterFunctionParser("domain", func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
|
||||
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
|
||||
return nil
|
||||
})
|
||||
if err = rb.Apply(rules); err != nil {
|
||||
return nil, fmt.Errorf("Apply: %w", err)
|
||||
}
|
||||
return builder.simulatedDomainSet, nil
|
||||
}
|
||||
|
139
component/routing/function_parser.go
Normal file
139
component/routing/function_parser.go
Normal file
@ -0,0 +1,139 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type FunctionParser func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)
|
||||
|
||||
// Preset function parser factories.
|
||||
|
||||
// PlainParserFactory is for style unity.
|
||||
func PlainParserFactory(callback func(f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
return callback(f, key, paramValueGroup, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
// EmptyKeyPlainParserFactory only accepts function with empty key.
|
||||
func EmptyKeyPlainParserFactory(callback func(f *config_parser.Function, values []string, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
if key != "" {
|
||||
return fmt.Errorf("this function cannot accept a key")
|
||||
}
|
||||
return callback(f, paramValueGroup, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func IpParserFactory(callback func(f *config_parser.Function, cidrs []netip.Prefix, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
cidrs, err := parsePrefixes(paramValueGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return callback(f, cidrs, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func MacParserFactory(callback func(f *config_parser.Function, macAddrs [][6]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var macAddrs [][6]byte
|
||||
for _, v := range paramValueGroup {
|
||||
mac, err := common.ParseMac(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
macAddrs = append(macAddrs, mac)
|
||||
}
|
||||
return callback(f, macAddrs, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func PortRangeParserFactory(callback func(f *config_parser.Function, portRanges [][2]uint16, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var portRanges [][2]uint16
|
||||
for _, v := range paramValueGroup {
|
||||
portRange, err := common.ParsePortRange(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
portRanges = append(portRanges, portRange)
|
||||
}
|
||||
return callback(f, portRanges, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func L4ProtoParserFactory(callback func(f *config_parser.Function, l4protoType consts.L4ProtoType, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var l4protoType consts.L4ProtoType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "tcp":
|
||||
l4protoType |= consts.L4ProtoType_TCP
|
||||
case "udp":
|
||||
l4protoType |= consts.L4ProtoType_UDP
|
||||
}
|
||||
}
|
||||
return callback(f, l4protoType, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func IpVersionParserFactory(callback func(f *config_parser.Function, ipVersion consts.IpVersionType, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var ipVersion consts.IpVersionType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "4":
|
||||
ipVersion |= consts.IpVersion_4
|
||||
case "6":
|
||||
ipVersion |= consts.IpVersion_6
|
||||
}
|
||||
}
|
||||
return callback(f, ipVersion, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func ProcessNameParserFactory(callback func(f *config_parser.Function, procNames [][consts.TaskCommLen]byte, overrideOutbound *Outbound) (err error)) FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *Outbound) (err error) {
|
||||
var procNames [][consts.TaskCommLen]byte
|
||||
for _, v := range paramValueGroup {
|
||||
if len([]byte(v)) > consts.TaskCommLen {
|
||||
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
|
||||
}
|
||||
procNames = append(procNames, toProcessName(v))
|
||||
}
|
||||
return callback(f, procNames, overrideOutbound)
|
||||
}
|
||||
}
|
||||
|
||||
func parsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
|
||||
for _, value := range values {
|
||||
toParse := value
|
||||
if strings.LastIndexByte(value, '/') == -1 {
|
||||
toParse += "/32"
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(toParse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
|
||||
}
|
||||
cidrs = append(cidrs, prefix)
|
||||
}
|
||||
return cidrs, nil
|
||||
}
|
||||
|
||||
func toProcessName(processName string) (procName [consts.TaskCommLen]byte) {
|
||||
n := []byte(processName)
|
||||
copy(procName[:], n)
|
||||
return procName
|
||||
}
|
@ -8,18 +8,11 @@ package routing
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var FakeOutbound_MUST_DIRECT = consts.OutboundMustDirect.String()
|
||||
var FakeOutbound_AND = consts.OutboundLogicalAnd.String()
|
||||
var FakeOutbound_OR = consts.OutboundLogicalOr.String()
|
||||
|
||||
type DomainSet struct {
|
||||
Key consts.RoutingDomainKey
|
||||
RuleIndex int
|
||||
@ -31,22 +24,71 @@ type Outbound struct {
|
||||
Mark uint32
|
||||
}
|
||||
|
||||
type MatcherBuilder interface {
|
||||
AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound)
|
||||
AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
|
||||
AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
|
||||
AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound)
|
||||
AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound)
|
||||
AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound)
|
||||
AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound)
|
||||
AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound)
|
||||
AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound)
|
||||
AddFallback(outbound *Outbound)
|
||||
AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound)
|
||||
AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound)
|
||||
type RulesBuilder struct {
|
||||
log *logrus.Logger
|
||||
parsers map[string]FunctionParser
|
||||
}
|
||||
|
||||
func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
|
||||
func NewRulesBuilder(log *logrus.Logger) *RulesBuilder {
|
||||
return &RulesBuilder{
|
||||
log: log,
|
||||
parsers: make(map[string]FunctionParser),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *RulesBuilder) RegisterFunctionParser(funcName string, parser FunctionParser) {
|
||||
b.parsers[funcName] = parser
|
||||
}
|
||||
|
||||
func (b *RulesBuilder) Apply(rules []*config_parser.RoutingRule) (err error) {
|
||||
for _, rule := range rules {
|
||||
b.log.Debugln("[rule]", rule.String(true))
|
||||
outbound, err := ParseOutbound(&rule.Outbound)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
|
||||
for iFunc, f := range rule.AndFunctions {
|
||||
// f is like: domain(domain:baidu.com)
|
||||
functionParser, ok := b.parsers[f.Name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown function: %v", f.Name)
|
||||
}
|
||||
paramValueGroups, keyOrder := groupParamValuesByKey(f.Params)
|
||||
for jMatchSet, key := range keyOrder {
|
||||
paramValueGroup := paramValueGroups[key]
|
||||
// Preprocess the outbound.
|
||||
overrideOutbound := &Outbound{
|
||||
Name: consts.OutboundLogicalOr.String(),
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
if jMatchSet == len(keyOrder)-1 {
|
||||
overrideOutbound.Name = consts.OutboundLogicalAnd.String()
|
||||
if iFunc == len(rule.AndFunctions)-1 {
|
||||
overrideOutbound.Name = outbound.Name
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Debug
|
||||
symNot := ""
|
||||
if f.Not {
|
||||
symNot = "!"
|
||||
}
|
||||
b.log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound.Name)
|
||||
}
|
||||
|
||||
if err = functionParser(b.log, f, key, paramValueGroup, overrideOutbound); err != nil {
|
||||
return fmt.Errorf("failed to parse '%v': %w", f.String(false), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func groupParamValuesByKey(params []*config_parser.Param) (keyToValues map[string][]string, keyOrder []string) {
|
||||
groups := make(map[string][]string)
|
||||
for _, param := range params {
|
||||
if _, ok := groups[param.Key]; !ok {
|
||||
@ -57,28 +99,7 @@ func GroupParamValuesByKey(params []*config_parser.Param) (keyToValues map[strin
|
||||
return groups, keyOrder
|
||||
}
|
||||
|
||||
func ParsePrefixes(values []string) (cidrs []netip.Prefix, err error) {
|
||||
for _, value := range values {
|
||||
toParse := value
|
||||
if strings.LastIndexByte(value, '/') == -1 {
|
||||
toParse += "/32"
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(toParse)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse %v: %w", value, err)
|
||||
}
|
||||
cidrs = append(cidrs, prefix)
|
||||
}
|
||||
return cidrs, nil
|
||||
}
|
||||
|
||||
func ToProcessName(processName string) (procName [consts.TaskCommLen]byte) {
|
||||
n := []byte(processName)
|
||||
copy(procName[:], n)
|
||||
return procName
|
||||
}
|
||||
|
||||
func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
|
||||
func ParseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err error) {
|
||||
outbound = &Outbound{
|
||||
Name: rawOutbound.Name,
|
||||
Mark: 0,
|
||||
@ -98,164 +119,3 @@ func parseOutbound(rawOutbound *config_parser.Function) (outbound *Outbound, err
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
func ApplyMatcherBuilder(log *logrus.Logger, builder MatcherBuilder, rules []*config_parser.RoutingRule, fallbackOutbound interface{}) (err error) {
|
||||
for _, rule := range rules {
|
||||
log.Debugln("[rule]", rule.String(true))
|
||||
outbound, err := parseOutbound(&rule.Outbound)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rule is like: domain(domain:baidu.com) && port(443) -> proxy
|
||||
for iFunc, f := range rule.AndFunctions {
|
||||
// f is like: domain(domain:baidu.com)
|
||||
paramValueGroups, keyOrder := GroupParamValuesByKey(f.Params)
|
||||
for jMatchSet, key := range keyOrder {
|
||||
paramValueGroup := paramValueGroups[key]
|
||||
// Preprocess the outbound.
|
||||
overrideOutbound := &Outbound{
|
||||
Name: FakeOutbound_OR,
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
if jMatchSet == len(keyOrder)-1 {
|
||||
overrideOutbound.Name = FakeOutbound_AND
|
||||
if iFunc == len(rule.AndFunctions)-1 {
|
||||
overrideOutbound.Name = outbound.Name
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Debug
|
||||
symNot := ""
|
||||
if f.Not {
|
||||
symNot = "!"
|
||||
}
|
||||
log.Debugf("\t%v%v(%v) -> %v", symNot, f.Name, key, overrideOutbound)
|
||||
}
|
||||
|
||||
builder.AddAnyBefore(f, key, paramValueGroup, overrideOutbound)
|
||||
switch f.Name {
|
||||
case consts.Function_Domain:
|
||||
builder.AddDomain(f, key, paramValueGroup, overrideOutbound)
|
||||
case consts.Function_Ip, consts.Function_SourceIp:
|
||||
cidrs, err := ParsePrefixes(paramValueGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f.Name == consts.Function_Ip {
|
||||
builder.AddIp(f, cidrs, overrideOutbound)
|
||||
} else {
|
||||
builder.AddSourceIp(f, cidrs, overrideOutbound)
|
||||
}
|
||||
case consts.Function_Mac:
|
||||
var macAddrs [][6]byte
|
||||
for _, v := range paramValueGroup {
|
||||
mac, err := common.ParseMac(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
macAddrs = append(macAddrs, mac)
|
||||
}
|
||||
builder.AddSourceMac(f, macAddrs, overrideOutbound)
|
||||
case consts.Function_Port, consts.Function_SourcePort:
|
||||
var portRanges [][2]uint16
|
||||
for _, v := range paramValueGroup {
|
||||
portRange, err := common.ParsePortRange(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
portRanges = append(portRanges, portRange)
|
||||
}
|
||||
if f.Name == consts.Function_Port {
|
||||
builder.AddPort(f, portRanges, overrideOutbound)
|
||||
} else {
|
||||
builder.AddSourcePort(f, portRanges, overrideOutbound)
|
||||
}
|
||||
case consts.Function_L4Proto:
|
||||
var l4protoType consts.L4ProtoType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "tcp":
|
||||
l4protoType |= consts.L4ProtoType_TCP
|
||||
case "udp":
|
||||
l4protoType |= consts.L4ProtoType_UDP
|
||||
}
|
||||
}
|
||||
builder.AddL4Proto(f, l4protoType, overrideOutbound)
|
||||
case consts.Function_IpVersion:
|
||||
var ipVersion consts.IpVersionType
|
||||
for _, v := range paramValueGroup {
|
||||
switch v {
|
||||
case "4":
|
||||
ipVersion |= consts.IpVersion_4
|
||||
case "6":
|
||||
ipVersion |= consts.IpVersion_6
|
||||
}
|
||||
}
|
||||
builder.AddIpVersion(f, ipVersion, overrideOutbound)
|
||||
case consts.Function_ProcessName:
|
||||
var procNames [][consts.TaskCommLen]byte
|
||||
for _, v := range paramValueGroup {
|
||||
if len([]byte(v)) > consts.TaskCommLen {
|
||||
log.Infof(`pname routing: trim "%v" to "%v" because it is too long.`, v, string([]byte(v)[:consts.TaskCommLen]))
|
||||
}
|
||||
procNames = append(procNames, ToProcessName(v))
|
||||
}
|
||||
builder.AddProcessName(f, procNames, overrideOutbound)
|
||||
default:
|
||||
return fmt.Errorf("unsupported function name: %v", f.Name)
|
||||
}
|
||||
builder.AddAnyAfter(f, key, paramValueGroup, overrideOutbound)
|
||||
}
|
||||
}
|
||||
}
|
||||
var rawFallback *config_parser.Function
|
||||
switch fallback := fallbackOutbound.(type) {
|
||||
case string:
|
||||
rawFallback = &config_parser.Function{Name: fallback}
|
||||
case *config_parser.Function:
|
||||
rawFallback = fallback
|
||||
default:
|
||||
return fmt.Errorf("unknown type of 'fallback' in section routing: %T", fallback)
|
||||
}
|
||||
fallback, err := parseOutbound(rawFallback)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
builder.AddAnyBefore(&config_parser.Function{
|
||||
Name: "fallback",
|
||||
}, "", nil, fallback)
|
||||
builder.AddFallback(fallback)
|
||||
builder.AddAnyAfter(&config_parser.Function{
|
||||
Name: "fallback",
|
||||
}, "", nil, fallback)
|
||||
return nil
|
||||
}
|
||||
|
||||
type DefaultMatcherBuilder struct {
|
||||
}
|
||||
|
||||
func (d *DefaultMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddSourceMac(f *config_parser.Function, values [][6]byte, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddFallback(outbound *Outbound) {}
|
||||
func (d *DefaultMatcherBuilder) AddAnyBefore(f *config_parser.Function, key string, values []string, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *Outbound) {
|
||||
}
|
||||
func (d *DefaultMatcherBuilder) AddAnyAfter(f *config_parser.Function, key string, values []string, outbound *Outbound) {
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ func (r *CryptoFrameRelocation) BytesFromPool() []byte {
|
||||
return pool.Get(0)
|
||||
}
|
||||
right := r.o[len(r.o)-1]
|
||||
return r.copyBytes(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
|
||||
return r.copyBytesToPool(0, 0, len(r.o)-1, len(right.Data)-1, r.length)
|
||||
}
|
||||
|
||||
// RangeFromPool copy bytes from iUpperAppOffset to jUpperAppOffset.
|
||||
@ -191,11 +191,11 @@ func (r *CryptoFrameRelocation) RangeFromPool(i, j int) []byte {
|
||||
}
|
||||
}
|
||||
|
||||
return r.copyBytes(iOuter, iInner, jOuter, jInner, j-i+1)
|
||||
return r.copyBytesToPool(iOuter, iInner, jOuter, jInner, j-i+1)
|
||||
}
|
||||
|
||||
// copyBytes copy bytes including i and j.
|
||||
func (r *CryptoFrameRelocation) copyBytes(iOuter, iInner, jOuter, jInner, size int) []byte {
|
||||
// copyBytesToPool copy bytes including i and j.
|
||||
func (r *CryptoFrameRelocation) copyBytesToPool(iOuter, iInner, jOuter, jInner, size int) []byte {
|
||||
b := pool.Get(size)
|
||||
//io := r.o[iOuter]
|
||||
k := 0
|
||||
|
@ -8,6 +8,7 @@ package sniffing
|
||||
import (
|
||||
"github.com/mzz2017/softwind/pool"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Sniffer struct {
|
||||
@ -15,12 +16,13 @@ type Sniffer struct {
|
||||
buf []byte
|
||||
bufAt int
|
||||
stream bool
|
||||
readMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewStreamSniffer(r io.Reader, bufSize int) *Sniffer {
|
||||
s := &Sniffer{
|
||||
r: r,
|
||||
buf: pool.Get(bufSize),
|
||||
buf: make([]byte, bufSize),
|
||||
stream: true,
|
||||
}
|
||||
return s
|
||||
@ -37,6 +39,8 @@ func NewPacketSniffer(data []byte) *Sniffer {
|
||||
type sniff func() (d string, err error)
|
||||
|
||||
func (s *Sniffer) SniffTcp() (d string, err error) {
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
if s.stream {
|
||||
n, err := s.r.Read(s.buf)
|
||||
if err != nil {
|
||||
@ -65,6 +69,8 @@ func (s *Sniffer) SniffTcp() (d string, err error) {
|
||||
}
|
||||
|
||||
func (s *Sniffer) Read(p []byte) (n int, err error) {
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
if s.buf != nil && s.bufAt < len(s.buf) {
|
||||
// Read buf first.
|
||||
n = copy(p, s.buf[s.bufAt:])
|
||||
@ -84,6 +90,5 @@ func (s *Sniffer) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (s *Sniffer) Close() (err error) {
|
||||
// DO NOT use pool.Put() here because Close() may not interrupt the reading, which will modify the value of the pool buffer.
|
||||
return nil
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"reflect"
|
||||
"time"
|
||||
@ -18,40 +17,67 @@ type Global struct {
|
||||
LogLevel string `mapstructure:"log_level" default:"info"`
|
||||
// We use DirectTcpCheckUrl to check (tcp)*(ipv4/ipv6) connectivity for direct.
|
||||
//DirectTcpCheckUrl string `mapstructure:"direct_tcp_check_url" default:"http://www.qualcomm.cn/generate_204"`
|
||||
TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://keep-alv.google.com/generate_204"`
|
||||
UdpCheckDns string `mapstructure:"udp_check_dns" default:"dns.google:53"`
|
||||
CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"`
|
||||
CheckTolerance time.Duration `mapstructure:"check_tolerance" default:"0"`
|
||||
DnsUpstream common.UrlOrEmpty `mapstructure:"dns_upstream" default:""`
|
||||
LanInterface []string `mapstructure:"lan_interface"`
|
||||
LanNatDirect bool `mapstructure:"lan_nat_direct" default:"true"`
|
||||
WanInterface []string `mapstructure:"wan_interface"`
|
||||
AllowInsecure bool `mapstructure:"allow_insecure" default:"false"`
|
||||
DialMode string `mapstructure:"dial_mode" default:"domain"`
|
||||
TcpCheckUrl string `mapstructure:"tcp_check_url" default:"http://keep-alv.google.com/generate_204"`
|
||||
UdpCheckDns string `mapstructure:"udp_check_dns" default:"dns.google:53"`
|
||||
CheckInterval time.Duration `mapstructure:"check_interval" default:"30s"`
|
||||
CheckTolerance time.Duration `mapstructure:"check_tolerance" default:"0"`
|
||||
DnsUpstream string `mapstructure:"dns_upstream" default:"<empty>"`
|
||||
LanInterface []string `mapstructure:"lan_interface"`
|
||||
LanNatDirect bool `mapstructure:"lan_nat_direct" default:"true"`
|
||||
WanInterface []string `mapstructure:"wan_interface"`
|
||||
AllowInsecure bool `mapstructure:"allow_insecure" default:"false"`
|
||||
DialMode string `mapstructure:"dial_mode" default:"domain"`
|
||||
}
|
||||
|
||||
type FunctionOrString interface{}
|
||||
|
||||
func FunctionOrStringToFunction(fs FunctionOrString) (f *config_parser.Function) {
|
||||
switch fs := fs.(type) {
|
||||
case string:
|
||||
return &config_parser.Function{Name: fs}
|
||||
case *config_parser.Function:
|
||||
return fs
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown type of 'fallback' in section routing: %T", fs))
|
||||
}
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
Name string
|
||||
Param GroupParam
|
||||
}
|
||||
Name string `mapstructure:"_"`
|
||||
|
||||
type GroupParam struct {
|
||||
Filter []*config_parser.Function `mapstructure:"filter"`
|
||||
Policy interface{} `mapstructure:"policy" required:""`
|
||||
}
|
||||
|
||||
type DnsRequestRouting struct {
|
||||
Rules []*config_parser.RoutingRule `mapstructure:"_"`
|
||||
Fallback FunctionOrString `mapstructure:"fallback" required:""`
|
||||
}
|
||||
type DnsResponseRouting struct {
|
||||
Rules []*config_parser.RoutingRule `mapstructure:"_"`
|
||||
Fallback FunctionOrString `mapstructure:"fallback" required:""`
|
||||
}
|
||||
type Dns struct {
|
||||
Upstream []string `mapstructure:"upstream"`
|
||||
Routing struct {
|
||||
Request DnsRequestRouting `mapstructure:"request"`
|
||||
Response DnsResponseRouting `mapstructure:"response"`
|
||||
} `mapstructure:"routing"`
|
||||
}
|
||||
|
||||
type Routing struct {
|
||||
Rules []*config_parser.RoutingRule `mapstructure:"_"`
|
||||
Fallback interface{} `mapstructure:"fallback"`
|
||||
Final interface{} `mapstructure:"final"`
|
||||
Fallback FunctionOrString `mapstructure:"fallback"`
|
||||
Final FunctionOrString `mapstructure:"final"`
|
||||
}
|
||||
|
||||
type Params struct {
|
||||
Global Global `mapstructure:"global" parser:"ParamParser"`
|
||||
Subscription []string `mapstructure:"subscription" parser:"StringListParser"`
|
||||
Node []string `mapstructure:"node" parser:"StringListParser"`
|
||||
Group []Group `mapstructure:"group" parser:"GroupListParser"`
|
||||
Routing Routing `mapstructure:"routing" parser:"RoutingRuleAndParamParser"`
|
||||
Global Global `mapstructure:"global" required:""`
|
||||
Subscription []string `mapstructure:"subscription"`
|
||||
Node []string `mapstructure:"node"`
|
||||
Group []Group `mapstructure:"group" required:""`
|
||||
Routing Routing `mapstructure:"routing" required:""`
|
||||
Dns Dns `mapstructure:"dns"`
|
||||
}
|
||||
|
||||
// New params from sections. This func assumes merging (section "include") and deduplication for section names has been executed.
|
||||
@ -82,21 +108,15 @@ func New(sections []*config_parser.Section) (params *Params, err error) {
|
||||
}
|
||||
section, ok := nameToSection[sectionName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("section %v is required but not provided", sectionName)
|
||||
}
|
||||
|
||||
// Find corresponding parser func.
|
||||
parserName, ok := structField.Tag.Lookup("parser")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no parser is specified in field %v", structField.Name)
|
||||
}
|
||||
parser, ok := ParserMap[parserName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown parser %v in field %v", parserName, structField.Name)
|
||||
if _, required := structField.Tag.Lookup("required"); required {
|
||||
return nil, fmt.Errorf("section %v is required but not provided", sectionName)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Parse section and unmarshal to field.
|
||||
if err := parser(field.Addr(), section.Val); err != nil {
|
||||
if err := SectionParser(field.Addr(), section.Val); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse \"%v\": %w", sectionName, err)
|
||||
}
|
||||
section.Parsed = true
|
||||
|
@ -8,12 +8,15 @@ package config
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
)
|
||||
|
||||
type patch func(params *Params) error
|
||||
|
||||
var patches = []patch{
|
||||
patchRoutingFallback,
|
||||
patchEmptyDns,
|
||||
patchDeprecatedGlobalDnsUpstream,
|
||||
}
|
||||
|
||||
func patchRoutingFallback(params *Params) error {
|
||||
@ -28,3 +31,20 @@ func patchRoutingFallback(params *Params) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func patchEmptyDns(params *Params) error {
|
||||
if params.Dns.Routing.Request.Fallback == nil {
|
||||
params.Dns.Routing.Request.Fallback = consts.DnsRequestOutboundIndex_AsIs.String()
|
||||
}
|
||||
if params.Dns.Routing.Response.Fallback == nil {
|
||||
params.Dns.Routing.Response.Fallback = consts.DnsResponseOutboundIndex_Accept.String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func patchDeprecatedGlobalDnsUpstream(params *Params) error {
|
||||
if params.Global.DnsUpstream != "<empty>" {
|
||||
return fmt.Errorf("'global.dns_upstream' was deprecated, please refer to the latest examples and docs for help")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -13,16 +13,6 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Parser is section items parser
|
||||
type Parser func(to reflect.Value, section *config_parser.Section) error
|
||||
|
||||
var ParserMap = map[string]Parser{
|
||||
"StringListParser": StringListParser,
|
||||
"ParamParser": ParamParser,
|
||||
"GroupListParser": GroupListParser,
|
||||
"RoutingRuleAndParamParser": RoutingRuleAndParamParser,
|
||||
}
|
||||
|
||||
func StringListParser(to reflect.Value, section *config_parser.Section) error {
|
||||
if to.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("StringListParser can only unmarshal section to *[]string")
|
||||
@ -44,7 +34,7 @@ func StringListParser(to reflect.Value, section *config_parser.Section) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []reflect.Type) error {
|
||||
func ParamParser(to reflect.Value, section *config_parser.Section, ignoreType []reflect.Type) error {
|
||||
if to.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("ParamParser can only unmarshal section to *struct")
|
||||
}
|
||||
@ -67,7 +57,7 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
|
||||
// Set up key to field mapping.
|
||||
key, ok := structField.Tag.Lookup("mapstructure")
|
||||
if !ok {
|
||||
return fmt.Errorf("field %v has no mapstructure tag", structField.Name)
|
||||
return fmt.Errorf("field \"%v\" has no mapstructure tag", structField.Name)
|
||||
}
|
||||
if key == "_" {
|
||||
// omit
|
||||
@ -95,11 +85,11 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
|
||||
switch itemVal := item.Value.(type) {
|
||||
case *config_parser.Param:
|
||||
if itemVal.Key == "" {
|
||||
return fmt.Errorf("section %v does not support text without a key: %v", section.Name, itemVal.String(true))
|
||||
return fmt.Errorf("unsupported text without a key: %v", itemVal.String(true))
|
||||
}
|
||||
field, ok := keyToField[itemVal.Key]
|
||||
if !ok {
|
||||
return fmt.Errorf("section %v does not support key: %v", section.Name, itemVal.Key)
|
||||
return fmt.Errorf("unexpected key: %v", itemVal.Key)
|
||||
}
|
||||
if itemVal.AndFunctions != nil {
|
||||
// AndFunctions.
|
||||
@ -108,7 +98,7 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
|
||||
field.Val.Type() == reflect.TypeOf(itemVal.AndFunctions) {
|
||||
field.Val.Set(reflect.ValueOf(itemVal.AndFunctions))
|
||||
} else {
|
||||
return fmt.Errorf("failed to parse \"%v.%v\": value \"%v\" cannot be convert to %v", section.Name, itemVal.Key, itemVal.Val, field.Val.Type().String())
|
||||
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().String())
|
||||
}
|
||||
} else {
|
||||
// String value.
|
||||
@ -122,21 +112,42 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
|
||||
for _, value := range values {
|
||||
vPointerNew := reflect.New(field.Val.Type().Elem())
|
||||
if !common.FuzzyDecode(vPointerNew.Interface(), value) {
|
||||
return fmt.Errorf("failed to parse \"%v.%v\": value \"%v\" cannot be convert to %v", section.Name, itemVal.Key, itemVal.Val, field.Val.Type().Elem().String())
|
||||
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().Elem().String())
|
||||
}
|
||||
field.Val.Set(reflect.Append(field.Val, vPointerNew.Elem()))
|
||||
}
|
||||
default:
|
||||
// Field is not interface{}, we can decode.
|
||||
if !common.FuzzyDecode(field.Val.Addr().Interface(), itemVal.Val) {
|
||||
return fmt.Errorf("failed to parse \"%v.%v\": value \"%v\" cannot be convert to %v", section.Name, itemVal.Key, itemVal.Val, field.Val.Type().String())
|
||||
return fmt.Errorf("failed to parse \"%v\": value \"%v\" cannot be convert to %v", itemVal.Key, itemVal.Val, field.Val.Type().String())
|
||||
}
|
||||
}
|
||||
}
|
||||
field.Set = true
|
||||
case *config_parser.Section:
|
||||
// Named section config item.
|
||||
field, ok := keyToField[itemVal.Name]
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected key: %v", itemVal.Name)
|
||||
}
|
||||
if err := SectionParser(field.Val.Addr(), itemVal); err != nil {
|
||||
return fmt.Errorf("failed to parse %v: %w", itemVal.Name, err)
|
||||
}
|
||||
field.Set = true
|
||||
case *config_parser.RoutingRule:
|
||||
// Assign. "to" should have field "Rules".
|
||||
structField, ok := to.Type().FieldByName("Rules")
|
||||
if !ok || structField.Type != reflect.TypeOf([]*config_parser.RoutingRule{}) {
|
||||
return fmt.Errorf("unexpected type: \"routing rule\": %v", itemVal.String(true))
|
||||
}
|
||||
if structField.Tag.Get("mapstructure") != "_" {
|
||||
return fmt.Errorf("a []*RoutingRule field \"Rules\" with mapstructure:\"_\" is required in struct %v to parse section", to.Type().String())
|
||||
}
|
||||
field := to.FieldByName("Rules")
|
||||
field.Set(reflect.Append(field, reflect.ValueOf(itemVal)))
|
||||
default:
|
||||
if _, ignore := ignoreTypeSet[reflect.TypeOf(itemVal)]; !ignore {
|
||||
return fmt.Errorf("section %v does not support type %v: %v", section.Name, item.Type.String(), item.String())
|
||||
return fmt.Errorf("unexpected type %v: %v", item.Type.String(), item.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -155,76 +166,66 @@ func paramParser(to reflect.Value, section *config_parser.Section, ignoreType []
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParamParser(to reflect.Value, section *config_parser.Section) error {
|
||||
return paramParser(to, section, nil)
|
||||
}
|
||||
|
||||
func GroupListParser(to reflect.Value, section *config_parser.Section) error {
|
||||
func SectionParser(to reflect.Value, section *config_parser.Section) error {
|
||||
if to.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("GroupListParser can only unmarshal section to *[]Group")
|
||||
return fmt.Errorf("SectionParser can only unmarshal section to a pointer")
|
||||
}
|
||||
to = to.Elem()
|
||||
if to.Type() != reflect.TypeOf([]Group{}) {
|
||||
return fmt.Errorf("GroupListParser can only unmarshal section to *[]Group")
|
||||
}
|
||||
|
||||
for _, item := range section.Items {
|
||||
switch itemVal := item.Value.(type) {
|
||||
case *config_parser.Section:
|
||||
group := Group{
|
||||
Name: itemVal.Name,
|
||||
Param: GroupParam{},
|
||||
switch to.Kind() {
|
||||
case reflect.Slice:
|
||||
elemType := to.Type().Elem()
|
||||
switch elemType.Kind() {
|
||||
case reflect.String:
|
||||
return StringListParser(to.Addr(), section)
|
||||
case reflect.Struct:
|
||||
// "to" is a section list (sections in section).
|
||||
/**
|
||||
to {
|
||||
field1 {
|
||||
...
|
||||
}
|
||||
field2 {
|
||||
...
|
||||
}
|
||||
}
|
||||
should be parsed to:
|
||||
to []struct {
|
||||
Name string `mapstructure: "_"`
|
||||
...
|
||||
}
|
||||
*/
|
||||
// The struct should contain Name.
|
||||
nameStructField, ok := elemType.FieldByName("Name")
|
||||
if !ok || nameStructField.Type.Kind() != reflect.String || nameStructField.Tag.Get("mapstructure") != "_" {
|
||||
return fmt.Errorf("a string field \"Name\" with mapstructure:\"_\" is required in struct %v to parse section", to.Type().Elem().String())
|
||||
}
|
||||
paramVal := reflect.ValueOf(&group.Param)
|
||||
if err := paramParser(paramVal, itemVal, nil); err != nil {
|
||||
return fmt.Errorf("failed to parse \"%v\": %w", itemVal.Name, err)
|
||||
// Scan sections.
|
||||
for _, item := range section.Items {
|
||||
elem := reflect.New(elemType).Elem()
|
||||
switch itemVal := item.Value.(type) {
|
||||
case *config_parser.Section:
|
||||
elem.FieldByName("Name").SetString(itemVal.Name)
|
||||
if err := SectionParser(elem.Addr(), itemVal); err != nil {
|
||||
return fmt.Errorf("error when parse \"%v\": %w", itemVal.Name, err)
|
||||
}
|
||||
to.Set(reflect.Append(to, elem))
|
||||
default:
|
||||
return fmt.Errorf("unmatched type: %v -> %v", item.Type.String(), elemType)
|
||||
}
|
||||
}
|
||||
to.Set(reflect.Append(to, reflect.ValueOf(group)))
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("section %v does not support type %v: %v", section.Name, item.Type.String(), item.String())
|
||||
goto unsupported
|
||||
}
|
||||
case reflect.Struct:
|
||||
// Section.
|
||||
return ParamParser(to.Addr(), section, nil)
|
||||
default:
|
||||
goto unsupported
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RoutingRuleAndParamParser(to reflect.Value, section *config_parser.Section) error {
|
||||
if to.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("RoutingRuleAndParamParser can only unmarshal section to *struct")
|
||||
}
|
||||
to = to.Elem()
|
||||
if to.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("RoutingRuleAndParamParser can only unmarshal section to *struct")
|
||||
}
|
||||
|
||||
// Find the first []*RoutingRule field to unmarshal.
|
||||
targetType := reflect.TypeOf([]*config_parser.RoutingRule{})
|
||||
var ruleTo *reflect.Value
|
||||
for i := 0; i < to.NumField(); i++ {
|
||||
field := to.Field(i)
|
||||
|
||||
if field.Type() == targetType {
|
||||
ruleTo = &field
|
||||
break
|
||||
}
|
||||
}
|
||||
if ruleTo == nil {
|
||||
return fmt.Errorf(`no %v field found`, targetType.String())
|
||||
}
|
||||
|
||||
// Parse and unmarshal list of RoutingRule to ruleTo.
|
||||
for _, item := range section.Items {
|
||||
switch itemVal := item.Value.(type) {
|
||||
case *config_parser.RoutingRule:
|
||||
ruleTo.Set(reflect.Append(*ruleTo, reflect.ValueOf(itemVal)))
|
||||
case *config_parser.Param:
|
||||
// pass
|
||||
default:
|
||||
return fmt.Errorf("section %v does not support type %v: %v", section.Name, item.Type.String(), item.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Parse Param.
|
||||
return paramParser(to.Addr(), section,
|
||||
[]reflect.Type{reflect.TypeOf(&config_parser.RoutingRule{})},
|
||||
)
|
||||
|
||||
panic("code should not reach here")
|
||||
|
||||
unsupported:
|
||||
return fmt.Errorf("unsupported section type %v", to.Type())
|
||||
}
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/dns"
|
||||
"github.com/v2rayA/dae/component/outbound"
|
||||
"github.com/v2rayA/dae/component/outbound/dialer"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
@ -24,9 +25,9 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -44,10 +45,8 @@ type ControlPlane struct {
|
||||
// TODO: add mutex?
|
||||
outbounds []*outbound.DialerGroup
|
||||
|
||||
// mutex protects the dnsCache.
|
||||
dnsCacheMu sync.Mutex
|
||||
dnsCache map[string]*dnsCache
|
||||
dnsUpstream DnsUpstreamRaw
|
||||
dnsController *DnsController
|
||||
onceNetworkReady sync.Once
|
||||
|
||||
dialMode consts.DialMode
|
||||
|
||||
@ -60,6 +59,7 @@ func NewControlPlane(
|
||||
groups []config.Group,
|
||||
routingA *config.Routing,
|
||||
global *config.Global,
|
||||
dnsConfig *config.Dns,
|
||||
) (c *ControlPlane, err error) {
|
||||
kernelVersion, e := internal.KernelVersion()
|
||||
if e != nil {
|
||||
@ -199,13 +199,6 @@ func NewControlPlane(
|
||||
}
|
||||
|
||||
/// DialerGroups (outbounds).
|
||||
checkDnsTcp := false
|
||||
if !global.DnsUpstream.Empty {
|
||||
if scheme, _, _, err := ParseDnsUpstream(global.DnsUpstream.Url); err == nil &&
|
||||
scheme.ContainsTcp() {
|
||||
checkDnsTcp = true
|
||||
}
|
||||
}
|
||||
if global.AllowInsecure {
|
||||
log.Warnln("AllowInsecure is enabled, but it is not recommended. Please make sure you have to turn it on.")
|
||||
}
|
||||
@ -215,7 +208,7 @@ func NewControlPlane(
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: global.UdpCheckDns},
|
||||
CheckInterval: global.CheckInterval,
|
||||
CheckTolerance: global.CheckTolerance,
|
||||
CheckDnsTcp: checkDnsTcp,
|
||||
CheckDnsTcp: true,
|
||||
AllowInsecure: global.AllowInsecure,
|
||||
}
|
||||
outbounds := []*outbound.DialerGroup{
|
||||
@ -237,12 +230,12 @@ func NewControlPlane(
|
||||
dialerSet := outbound.NewDialerSetFromLinks(option, tagToNodeList)
|
||||
for _, group := range groups {
|
||||
// Parse policy.
|
||||
policy, err := outbound.NewDialerSelectionPolicyFromGroupParam(&group.Param)
|
||||
policy, err := outbound.NewDialerSelectionPolicyFromGroupParam(&group)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create group %v: %w", group.Name, err)
|
||||
}
|
||||
// Filter nodes with user given filters.
|
||||
dialers, err := dialerSet.Filter(group.Param.Filter)
|
||||
dialers, err := dialerSet.Filter(group.Filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`failed to create group "%v": %w`, group.Name, err)
|
||||
}
|
||||
@ -276,7 +269,7 @@ func NewControlPlane(
|
||||
outboundId2Name[uint8(i)] = o.Name
|
||||
}
|
||||
core.outboundId2Name = outboundId2Name
|
||||
builder := NewRoutingMatcherBuilder(outboundName2Id, &bpf)
|
||||
// Apply rules optimizers.
|
||||
var rules []*config_parser.RoutingRule
|
||||
if rules, err = routing.ApplyRulesOptimizers(routingA.Rules,
|
||||
&routing.RefineFunctionParamKeyOptimizer{},
|
||||
@ -294,121 +287,134 @@ func NewControlPlane(
|
||||
}
|
||||
log.Debugf("RoutingA:\n%vfallback: %v\n", debugBuilder.String(), routingA.Fallback)
|
||||
}
|
||||
if err = routing.ApplyMatcherBuilder(log, builder, rules, routingA.Fallback); err != nil {
|
||||
return nil, fmt.Errorf("ApplyMatcherBuilder: %w", err)
|
||||
// Parse rules and build.
|
||||
builder, err := NewRoutingMatcherBuilder(log, rules, outboundName2Id, &bpf, routingA.Fallback)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewRoutingMatcherBuilder: %w", err)
|
||||
}
|
||||
if err = builder.BuildKernspace(); err != nil {
|
||||
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err)
|
||||
}
|
||||
routingMatcher, err := builder.BuildUserspace()
|
||||
routingMatcher, err := builder.BuildUserspace(core.bpf.LpmArrayMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err)
|
||||
}
|
||||
|
||||
/// Dial mode.
|
||||
dialMode, err := consts.ParseDialMode(global.DialMode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c = &ControlPlane{
|
||||
log: log,
|
||||
core: core,
|
||||
deferFuncs: nil,
|
||||
listenIp: "0.0.0.0",
|
||||
outbounds: outbounds,
|
||||
dnsCacheMu: sync.Mutex{},
|
||||
dnsCache: make(map[string]*dnsCache),
|
||||
dnsUpstream: DnsUpstreamRaw{
|
||||
Raw: global.DnsUpstream,
|
||||
FinishInitCallback: nil,
|
||||
},
|
||||
log: log,
|
||||
core: core,
|
||||
deferFuncs: nil,
|
||||
listenIp: "0.0.0.0",
|
||||
outbounds: outbounds,
|
||||
dialMode: dialMode,
|
||||
routingMatcher: routingMatcher,
|
||||
}
|
||||
|
||||
/// DNS upstream
|
||||
c.dnsUpstream.FinishInitCallback = c.finishInitDnsUpstreamResolve
|
||||
// Try to invoke once to avoid dns leaking at the very beginning.
|
||||
_, _ = c.dnsUpstream.GetUpstream()
|
||||
/// DNS upstream.
|
||||
dnsUpstream, err := dns.New(log, dnsConfig, &dns.NewOption{
|
||||
UpstreamReadyCallback: c.dnsUpstreamReadyCallback,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
/// Dns controller.
|
||||
c.dnsController, err = NewDnsController(dnsUpstream, &DnsControllerOption{
|
||||
Log: log,
|
||||
CacheAccessCallback: func(cache *DnsCache) (err error) {
|
||||
// Write mappings into eBPF map:
|
||||
// IP record (from dns lookup) -> domain routing
|
||||
if err = core.BatchUpdateDomainRouting(cache); err != nil {
|
||||
return fmt.Errorf("BatchUpdateDomainRouting: %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
NewCache: func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) {
|
||||
return &DnsCache{
|
||||
DomainBitmap: c.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn),
|
||||
Answers: answers,
|
||||
Deadline: deadline,
|
||||
}, nil
|
||||
},
|
||||
BestDialerChooser: c.chooseBestDnsDialer,
|
||||
})
|
||||
|
||||
// Call GC to release memory.
|
||||
runtime.GC()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *ControlPlane) finishInitDnsUpstreamResolve(raw common.UrlOrEmpty, dnsUpstream *DnsUpstream) (err error) {
|
||||
func (c *ControlPlane) dnsUpstreamReadyCallback(raw *url.URL, dnsUpstream *dns.Upstream) (err error) {
|
||||
/// Notify dialers to check.
|
||||
for _, out := range c.outbounds {
|
||||
for _, d := range out.Dialers {
|
||||
d.NotifyCheck()
|
||||
c.onceNetworkReady.Do(func() {
|
||||
for _, out := range c.outbounds {
|
||||
for _, d := range out.Dialers {
|
||||
d.NotifyCheck()
|
||||
}
|
||||
}
|
||||
if dnsUpstream != nil {
|
||||
// Control plane DNS routing.
|
||||
if err = c.core.bpf.ParamMap.Update(consts.ControlPlaneDnsRoutingKey, uint32(1), ebpf.UpdateAny); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// As-is.
|
||||
if err = c.core.bpf.ParamMap.Update(consts.ControlPlaneDnsRoutingKey, uint32(0), ebpf.UpdateAny); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dnsUpstream == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
/// Updates dns cache to support domain routing for hostname of dns_upstream.
|
||||
if !raw.Empty {
|
||||
ip4in6 := dnsUpstream.Ip4.As16()
|
||||
ip6 := dnsUpstream.Ip6.As16()
|
||||
if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
|
||||
Ip4: common.Ipv6ByteSliceToUint32Array(ip4in6[:]),
|
||||
Ip6: common.Ipv6ByteSliceToUint32Array(ip6[:]),
|
||||
HasIp4: dnsUpstream.Ip4.IsValid(),
|
||||
HasIp6: dnsUpstream.Ip6.IsValid(),
|
||||
Port: common.Htons(dnsUpstream.Port),
|
||||
}, ebpf.UpdateAny); err != nil {
|
||||
// Ten years later.
|
||||
deadline := time.Now().Add(time.Hour * 24 * 365 * 10)
|
||||
fqdn := dnsUpstream.Hostname
|
||||
if !strings.HasSuffix(fqdn, ".") {
|
||||
fqdn = fqdn + "."
|
||||
}
|
||||
|
||||
if dnsUpstream.Ip4.IsValid() {
|
||||
typ := dnsmessage.TypeA
|
||||
answers := []dnsmessage.Resource{{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(fqdn),
|
||||
Type: typ,
|
||||
Class: dnsmessage.ClassINET,
|
||||
TTL: 0, // Must be zero.
|
||||
},
|
||||
Body: &dnsmessage.AResource{
|
||||
A: dnsUpstream.Ip4.As4(),
|
||||
},
|
||||
}}
|
||||
if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
/// Update dns cache to support domain routing for hostname of dns_upstream.
|
||||
// Ten years later.
|
||||
deadline := time.Now().Add(24 * time.Hour * 365 * 10)
|
||||
fqdn := dnsUpstream.Hostname
|
||||
if !strings.HasSuffix(fqdn, ".") {
|
||||
fqdn = fqdn + "."
|
||||
}
|
||||
}
|
||||
|
||||
if dnsUpstream.Ip4.IsValid() {
|
||||
typ := dnsmessage.TypeA
|
||||
answers := []dnsmessage.Resource{{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(fqdn),
|
||||
Type: typ,
|
||||
Class: dnsmessage.ClassINET,
|
||||
TTL: 0, // Must be zero.
|
||||
},
|
||||
Body: &dnsmessage.AResource{
|
||||
A: dnsUpstream.Ip4.As4(),
|
||||
},
|
||||
}}
|
||||
if err = c.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
|
||||
c = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if dnsUpstream.Ip6.IsValid() {
|
||||
typ := dnsmessage.TypeAAAA
|
||||
answers := []dnsmessage.Resource{{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(fqdn),
|
||||
Type: typ,
|
||||
Class: dnsmessage.ClassINET,
|
||||
TTL: 0, // Must be zero.
|
||||
},
|
||||
Body: &dnsmessage.AAAAResource{
|
||||
AAAA: dnsUpstream.Ip6.As16(),
|
||||
},
|
||||
}}
|
||||
if err = c.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
|
||||
c = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Empty string. As-is.
|
||||
if err = c.core.bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfDnsUpstream{
|
||||
Ip4: [4]uint32{},
|
||||
Ip6: [4]uint32{},
|
||||
HasIp4: false,
|
||||
HasIp6: false,
|
||||
// Zero port indicates no element, because bpf_map_lookup_elem cannot return 0 for map_type_array.
|
||||
Port: 0,
|
||||
}, ebpf.UpdateAny); err != nil {
|
||||
if dnsUpstream.Ip6.IsValid() {
|
||||
typ := dnsmessage.TypeAAAA
|
||||
answers := []dnsmessage.Resource{{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName(fqdn),
|
||||
Type: typ,
|
||||
Class: dnsmessage.ClassINET,
|
||||
TTL: 0, // Must be zero.
|
||||
},
|
||||
Body: &dnsmessage.AAAAResource{
|
||||
AAAA: dnsUpstream.Ip6.As16(),
|
||||
},
|
||||
}}
|
||||
if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -421,9 +427,8 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
|
||||
if !outbound.IsReserved() && domain != "" {
|
||||
switch c.dialMode {
|
||||
case consts.DialMode_Domain:
|
||||
dstIp := common.ConvergeIp(dst.Addr())
|
||||
cache := c.lookupDnsRespCache(domain, common.AddrToDnsType(dstIp))
|
||||
if cache != nil && cache.IncludeIp(dstIp) {
|
||||
cache := c.dnsController.LookupDnsRespCache(domain, common.AddrToDnsType(dst.Addr()))
|
||||
if cache != nil && cache.IncludeIp(dst.Addr()) {
|
||||
mode = consts.DialMode_Domain
|
||||
}
|
||||
case consts.DialMode_DomainPlus:
|
||||
@ -552,7 +557,7 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
|
||||
} else {
|
||||
realDst = pktDst
|
||||
}
|
||||
if e := c.handlePkt(udpConn, data, src, pktDst, realDst, routingResult); e != nil {
|
||||
if e := c.handlePkt(udpConn, data, common.ConvergeAddrPort(src), common.ConvergeAddrPort(pktDst), common.ConvergeAddrPort(realDst), routingResult); e != nil {
|
||||
c.log.Warnln("handlePkt:", e)
|
||||
}
|
||||
}(newBuf, src)
|
||||
@ -562,6 +567,103 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControlPlane) chooseBestDnsDialer(
|
||||
req *udpRequest,
|
||||
dnsUpstream *dns.Upstream,
|
||||
) (*dialArgument, error) {
|
||||
/// Choose the best l4proto+ipversion dialer, and change taregt DNS to the best ipversion DNS upstream for DNS request.
|
||||
// Get available ipversions and l4protos for DNS upstream.
|
||||
ipversions, l4protos := dnsUpstream.SupportedNetworks()
|
||||
var (
|
||||
bestLatency time.Duration
|
||||
l4proto consts.L4ProtoStr
|
||||
ipversion consts.IpVersionStr
|
||||
bestDialer *dialer.Dialer
|
||||
bestOutbound *outbound.DialerGroup
|
||||
bestTarget netip.AddrPort
|
||||
dialMark uint32
|
||||
)
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"ipversions": ipversions,
|
||||
"l4protos": l4protos,
|
||||
"upstream": dnsUpstream.String(),
|
||||
}).Traceln("Choose DNS path")
|
||||
}
|
||||
// Get the min latency path.
|
||||
networkType := dialer.NetworkType{
|
||||
IsDns: true,
|
||||
}
|
||||
for _, ver := range ipversions {
|
||||
for _, proto := range l4protos {
|
||||
networkType.L4Proto = proto
|
||||
networkType.IpVersion = ver
|
||||
var dAddr netip.Addr
|
||||
switch ver {
|
||||
case consts.IpVersionStr_4:
|
||||
dAddr = dnsUpstream.Ip4
|
||||
case consts.IpVersionStr_6:
|
||||
dAddr = dnsUpstream.Ip6
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected ipversion: %v", ver)
|
||||
}
|
||||
outboundIndex, mark, err := c.Route(req.realSrc, netip.AddrPortFrom(dAddr, dnsUpstream.Port), "", proto.ToL4ProtoType(), req.routingResult)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Already "must direct".
|
||||
if outboundIndex == consts.OutboundMustDirect {
|
||||
outboundIndex = consts.OutboundDirect
|
||||
}
|
||||
if int(outboundIndex) >= len(c.outbounds) {
|
||||
return nil, fmt.Errorf("bad outbound index: %v", outboundIndex)
|
||||
}
|
||||
dialerGroup := c.outbounds[outboundIndex]
|
||||
d, latency, err := dialerGroup.Select(&networkType)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
//if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
// c.log.WithFields(logrus.Fields{
|
||||
// "name": d.Name(),
|
||||
// "latency": latency,
|
||||
// "network": networkType.String(),
|
||||
// "outbound": dialerGroup.Name,
|
||||
// }).Traceln("Choice")
|
||||
//}
|
||||
if bestDialer == nil || latency < bestLatency {
|
||||
bestDialer = d
|
||||
bestOutbound = dialerGroup
|
||||
bestLatency = latency
|
||||
l4proto = proto
|
||||
ipversion = ver
|
||||
dialMark = mark
|
||||
|
||||
if bestLatency == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if bestDialer == nil {
|
||||
return nil, fmt.Errorf("no proper dialer for DNS upstream: %v", dnsUpstream.String())
|
||||
}
|
||||
switch ipversion {
|
||||
case consts.IpVersionStr_4:
|
||||
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip4, dnsUpstream.Port)
|
||||
case consts.IpVersionStr_6:
|
||||
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip6, dnsUpstream.Port)
|
||||
}
|
||||
return &dialArgument{
|
||||
l4proto: l4proto,
|
||||
ipversion: ipversion,
|
||||
bestDialer: bestDialer,
|
||||
bestOutbound: bestOutbound,
|
||||
bestTarget: bestTarget,
|
||||
mark: dialMark,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *ControlPlane) Close() (err error) {
|
||||
// Invoke defer funcs in reverse order.
|
||||
for i := len(c.deferFuncs) - 1; i >= 0; i-- {
|
||||
|
@ -11,11 +11,14 @@ import (
|
||||
ciliumLink "github.com/cilium/ebpf/link"
|
||||
"github.com/safchain/ethtool"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"golang.org/x/sys/unix"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
)
|
||||
@ -415,3 +418,42 @@ func (c *ControlPlaneCore) bindWan(ifname string) error {
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchUpdateDomainRouting update bpf map domain_routing. Since one IP may have multiple domains, this function should
|
||||
// be invoked every A/AAAA-record lookup.
|
||||
func (c *ControlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error {
|
||||
// Parse ips from DNS resp answers.
|
||||
var ips []netip.Addr
|
||||
for _, ans := range cache.Answers {
|
||||
switch ans.Header.Type {
|
||||
case dnsmessage.TypeA:
|
||||
ips = append(ips, netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A))
|
||||
case dnsmessage.TypeAAAA:
|
||||
ips = append(ips, netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA))
|
||||
}
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update bpf map.
|
||||
// Construct keys and vals, and BpfMapBatchUpdate.
|
||||
var keys [][4]uint32
|
||||
var vals []bpfDomainRouting
|
||||
for _, ip := range ips {
|
||||
ip6 := ip.As16()
|
||||
keys = append(keys, common.Ipv6ByteSliceToUint32Array(ip6[:]))
|
||||
r := bpfDomainRouting{}
|
||||
if len(cache.DomainBitmap) != len(r.Bitmap) {
|
||||
return fmt.Errorf("domain bitmap length not sync with kern program")
|
||||
}
|
||||
copy(r.Bitmap[:], cache.DomainBitmap)
|
||||
vals = append(vals, r)
|
||||
}
|
||||
if _, err := BpfMapBatchUpdate(c.bpf.DomainRoutingMap, keys, vals, &ebpf.BatchOptions{
|
||||
ElemFlags: uint64(ebpf.UpdateAny),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
370
control/dns.go
370
control/dns.go
@ -1,370 +0,0 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cilium/ebpf"
|
||||
"github.com/mohae/deepcopy"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
SuspectedRushAnswerError = fmt.Errorf("suspected DNS rush-answer")
|
||||
UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type")
|
||||
)
|
||||
|
||||
type dnsCache struct {
|
||||
DomainBitmap []uint32
|
||||
Answers []dnsmessage.Resource
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
func (c *dnsCache) FillInto(req *dnsmessage.Message) {
|
||||
req.Answers = deepcopy.Copy(c.Answers).([]dnsmessage.Resource)
|
||||
// Align question and answer Name.
|
||||
if len(req.Questions) > 0 {
|
||||
q := req.Questions[0]
|
||||
for i := range req.Answers {
|
||||
if strings.EqualFold(req.Answers[i].Header.Name.String(), q.Name.String()) {
|
||||
req.Answers[i].Header.Name.Data = q.Name.Data
|
||||
}
|
||||
}
|
||||
}
|
||||
req.RCode = dnsmessage.RCodeSuccess
|
||||
req.Response = true
|
||||
req.RecursionAvailable = true
|
||||
req.Truncated = false
|
||||
}
|
||||
|
||||
func (c *dnsCache) IncludeIp(ip netip.Addr) bool {
|
||||
ip = common.ConvergeIp(ip)
|
||||
for _, ans := range c.Answers {
|
||||
switch body := ans.Body.(type) {
|
||||
case *dnsmessage.AResource:
|
||||
if !ip.Is4() {
|
||||
continue
|
||||
}
|
||||
if netip.AddrFrom4(body.A) == ip {
|
||||
return true
|
||||
}
|
||||
case *dnsmessage.AAAAResource:
|
||||
if !ip.Is6() {
|
||||
continue
|
||||
}
|
||||
if netip.AddrFrom16(body.AAAA) == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BatchUpdateDomainRouting update bpf map domain_routing. Since one IP may have multiple domains, this function should
|
||||
// be invoked every A/AAAA-record lookup.
|
||||
func (c *ControlPlane) BatchUpdateDomainRouting(cache *dnsCache) error {
|
||||
// Parse ips from DNS resp answers.
|
||||
var ips []netip.Addr
|
||||
for _, ans := range cache.Answers {
|
||||
switch ans.Header.Type {
|
||||
case dnsmessage.TypeA:
|
||||
ips = append(ips, netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A))
|
||||
case dnsmessage.TypeAAAA:
|
||||
ips = append(ips, netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA))
|
||||
}
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update bpf map.
|
||||
// Construct keys and vals, and BpfMapBatchUpdate.
|
||||
var keys [][4]uint32
|
||||
var vals []bpfDomainRouting
|
||||
for _, ip := range ips {
|
||||
ip6 := ip.As16()
|
||||
keys = append(keys, common.Ipv6ByteSliceToUint32Array(ip6[:]))
|
||||
vals = append(vals, bpfDomainRouting{
|
||||
Bitmap: [3]uint32{},
|
||||
})
|
||||
if len(cache.DomainBitmap) != len(vals[len(vals)-1].Bitmap) {
|
||||
return fmt.Errorf("domain bitmap length not sync with kern program")
|
||||
}
|
||||
copy(vals[len(vals)-1].Bitmap[:], cache.DomainBitmap)
|
||||
}
|
||||
if _, err := BpfMapBatchUpdate(c.core.bpf.DomainRoutingMap, keys, vals, &ebpf.BatchOptions{
|
||||
ElemFlags: uint64(ebpf.UpdateAny),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControlPlane) lookupDnsRespCache(domain string, t dnsmessage.Type) (cache *dnsCache) {
|
||||
now := time.Now()
|
||||
|
||||
// To fqdn.
|
||||
if !strings.HasSuffix(domain, ".") {
|
||||
domain = domain + "."
|
||||
}
|
||||
c.dnsCacheMu.Lock()
|
||||
cache, ok := c.dnsCache[strings.ToLower(domain)+t.String()]
|
||||
c.dnsCacheMu.Unlock()
|
||||
if ok && cache.Deadline.After(now) {
|
||||
return cache
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ControlPlane) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) {
|
||||
if len(msg.Questions) == 0 {
|
||||
return nil
|
||||
}
|
||||
q := msg.Questions[0]
|
||||
if msg.Response {
|
||||
return nil
|
||||
}
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
cache := c.lookupDnsRespCache(q.Name.String(), q.Type)
|
||||
if cache != nil {
|
||||
cache.FillInto(msg)
|
||||
b, err := msg.Pack()
|
||||
if err != nil {
|
||||
c.log.Warnf("failed to pack: %v", err)
|
||||
return nil
|
||||
}
|
||||
if err = c.BatchUpdateDomainRouting(cache); err != nil {
|
||||
c.log.Warnf("failed to BatchUpdateDomainRouting: %v", err)
|
||||
return nil
|
||||
}
|
||||
return b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlipDnsQuestionCase is used to reduce dns pollution.
|
||||
func FlipDnsQuestionCase(dm *dnsmessage.Message) {
|
||||
if len(dm.Questions) == 0 {
|
||||
return
|
||||
}
|
||||
q := &dm.Questions[0]
|
||||
// For reproducibility, we use dm.ID as input and add some entropy to make the results more discrete.
|
||||
h := fnv.New64()
|
||||
var buf [4]byte
|
||||
binary.BigEndian.PutUint16(buf[:], dm.ID)
|
||||
h.Write(buf[:2])
|
||||
binary.BigEndian.PutUint32(buf[:], 20230204) // entropy
|
||||
h.Write(buf[:])
|
||||
r := rand.New(rand.NewSource(int64(h.Sum64())))
|
||||
perm := r.Perm(int(q.Name.Length))
|
||||
for i := 0; i < int(q.Name.Length/3); i++ {
|
||||
j := perm[i]
|
||||
// Upper to lower; lower to upper.
|
||||
if q.Name.Data[j] >= 'a' && q.Name.Data[j] <= 'z' {
|
||||
q.Name.Data[j] -= 'a' - 'A'
|
||||
} else if q.Name.Data[j] >= 'A' && q.Name.Data[j] <= 'Z' {
|
||||
q.Name.Data[j] += 'a' - 'A'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureAdditionalOpt makes sure there is additional record OPT in the request.
|
||||
func EnsureAdditionalOpt(dm *dnsmessage.Message, isReqAdd bool) (bool, error) {
|
||||
// Check healthy resp.
|
||||
if isReqAdd == dm.Response || dm.RCode != dnsmessage.RCodeSuccess || len(dm.Questions) == 0 {
|
||||
return false, UnsupportedQuestionTypeError
|
||||
}
|
||||
q := dm.Questions[0]
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
default:
|
||||
return false, UnsupportedQuestionTypeError
|
||||
}
|
||||
|
||||
for _, ad := range dm.Additionals {
|
||||
if ad.Header.Type == dnsmessage.TypeOPT {
|
||||
// Already has additional record OPT.
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if !isReqAdd {
|
||||
return false, nil
|
||||
}
|
||||
// Add one.
|
||||
dm.Additionals = append(dm.Additionals, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName("."),
|
||||
Type: dnsmessage.TypeOPT,
|
||||
Class: 512, TTL: 0, Length: 0,
|
||||
},
|
||||
Body: &dnsmessage.OPTResource{
|
||||
Options: nil,
|
||||
},
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type RscWrapper struct {
|
||||
Rsc dnsmessage.Resource
|
||||
}
|
||||
|
||||
func (w RscWrapper) String() string {
|
||||
return fmt.Sprintf("%v: %v", w.Rsc.Header.GoString(), w.Rsc.Body.GoString())
|
||||
}
|
||||
func FormatDnsRsc(ans []dnsmessage.Resource) (w []string) {
|
||||
for _, a := range ans {
|
||||
w = append(w, RscWrapper{Rsc: a}.String())
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// DnsRespHandler handle DNS resp. This function should be invoked when cache miss.
|
||||
func (c *ControlPlane) DnsRespHandler(data []byte, validateRushAns bool) (newData []byte, err error) {
|
||||
var msg dnsmessage.Message
|
||||
if err = msg.Unpack(data); err != nil {
|
||||
return nil, fmt.Errorf("unpack dns pkt: %w", err)
|
||||
}
|
||||
// Check healthy resp.
|
||||
if !msg.Response || len(msg.Questions) == 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
FlipDnsQuestionCase(&msg)
|
||||
q := msg.Questions[0]
|
||||
// Align Name.
|
||||
for i := range msg.Answers {
|
||||
if strings.EqualFold(msg.Answers[i].Header.Name.String(), q.Name.String()) {
|
||||
msg.Answers[i].Header.Name.Data = q.Name.Data
|
||||
}
|
||||
}
|
||||
for i := range msg.Additionals {
|
||||
if strings.EqualFold(msg.Additionals[i].Header.Name.String(), q.Name.String()) {
|
||||
msg.Additionals[i].Header.Name.Data = q.Name.Data
|
||||
}
|
||||
}
|
||||
for i := range msg.Authorities {
|
||||
if strings.EqualFold(msg.Authorities[i].Header.Name.String(), q.Name.String()) {
|
||||
msg.Authorities[i].Header.Name.Data = q.Name.Data
|
||||
}
|
||||
}
|
||||
|
||||
// Check suc resp.
|
||||
if msg.RCode != dnsmessage.RCodeSuccess {
|
||||
return msg.Pack()
|
||||
}
|
||||
|
||||
// Check req type.
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
default:
|
||||
return msg.Pack()
|
||||
}
|
||||
|
||||
// Set ttl.
|
||||
var ttl uint32
|
||||
for i := range msg.Answers {
|
||||
if ttl == 0 {
|
||||
ttl = msg.Answers[i].Header.TTL
|
||||
}
|
||||
// Set TTL = zero. This requests applications must resend every request.
|
||||
// However, it may be not defined in the standard.
|
||||
msg.Answers[i].Header.TTL = 0
|
||||
}
|
||||
|
||||
// Check if there is any A/AAAA record.
|
||||
var hasIpRecord bool
|
||||
loop:
|
||||
for i := range msg.Answers {
|
||||
switch msg.Answers[i].Header.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
hasIpRecord = true
|
||||
break loop
|
||||
}
|
||||
}
|
||||
if !hasIpRecord {
|
||||
return msg.Pack()
|
||||
}
|
||||
|
||||
if validateRushAns {
|
||||
exist, e := EnsureAdditionalOpt(&msg, false)
|
||||
if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) {
|
||||
c.log.Warnf("EnsureAdditionalOpt: %v", e)
|
||||
}
|
||||
if e == nil && !exist {
|
||||
// Additional record OPT in the request was ensured, and in normal case the resp should also set it.
|
||||
// This DNS packet may be a rush-answer, and we should reject it.
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"ques": q,
|
||||
"addi": FormatDnsRsc(msg.Additionals),
|
||||
"ans": FormatDnsRsc(msg.Answers),
|
||||
}).Traceln("DNS rush-answer detected")
|
||||
return nil, SuspectedRushAnswerError
|
||||
}
|
||||
}
|
||||
|
||||
// Update dnsCache.
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"qname": q.Name,
|
||||
"rcode": msg.RCode,
|
||||
"ans": FormatDnsRsc(msg.Answers),
|
||||
"auth": FormatDnsRsc(msg.Authorities),
|
||||
"addi": FormatDnsRsc(msg.Additionals),
|
||||
}).Tracef("Update DNS record cache")
|
||||
}
|
||||
if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Pack to get newData.
|
||||
return msg.Pack()
|
||||
}
|
||||
|
||||
func (c *ControlPlane) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) {
|
||||
var fqdn string
|
||||
if strings.HasSuffix(host, ".") {
|
||||
fqdn = host
|
||||
host = host[:len(host)-1]
|
||||
} else {
|
||||
fqdn = host + "."
|
||||
}
|
||||
// Bypass pure IP.
|
||||
if _, err = netip.ParseAddr(host); err == nil {
|
||||
return nil
|
||||
}
|
||||
cacheKey := fqdn + typ.String()
|
||||
c.dnsCacheMu.Lock()
|
||||
cache, ok := c.dnsCache[cacheKey]
|
||||
if ok {
|
||||
c.dnsCacheMu.Unlock()
|
||||
cache.Deadline = deadline
|
||||
cache.Answers = answers
|
||||
} else {
|
||||
cache = &dnsCache{
|
||||
DomainBitmap: c.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn),
|
||||
Answers: answers,
|
||||
Deadline: deadline,
|
||||
}
|
||||
c.dnsCache[cacheKey] = cache
|
||||
c.dnsCacheMu.Unlock()
|
||||
}
|
||||
if err = c.BatchUpdateDomainRouting(cache); err != nil {
|
||||
return fmt.Errorf("BatchUpdateDomainRouting: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
59
control/dns_cache.go
Normal file
59
control/dns_cache.go
Normal file
@ -0,0 +1,59 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2022-2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"github.com/mohae/deepcopy"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DnsCache struct {
|
||||
DomainBitmap []uint32
|
||||
Answers []dnsmessage.Resource
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
func (c *DnsCache) FillInto(req *dnsmessage.Message) {
|
||||
req.Answers = deepcopy.Copy(c.Answers).([]dnsmessage.Resource)
|
||||
// Align question and answer Name.
|
||||
if len(req.Questions) > 0 {
|
||||
q := req.Questions[0]
|
||||
for i := range req.Answers {
|
||||
if strings.EqualFold(req.Answers[i].Header.Name.String(), q.Name.String()) {
|
||||
req.Answers[i].Header.Name.Data = q.Name.Data
|
||||
}
|
||||
}
|
||||
}
|
||||
req.RCode = dnsmessage.RCodeSuccess
|
||||
req.Response = true
|
||||
req.RecursionAvailable = true
|
||||
req.Truncated = false
|
||||
}
|
||||
|
||||
func (c *DnsCache) IncludeIp(ip netip.Addr) bool {
|
||||
for _, ans := range c.Answers {
|
||||
switch body := ans.Body.(type) {
|
||||
case *dnsmessage.AResource:
|
||||
if !ip.Is4() {
|
||||
continue
|
||||
}
|
||||
if netip.AddrFrom4(body.A) == ip {
|
||||
return true
|
||||
}
|
||||
case *dnsmessage.AAAAResource:
|
||||
if !ip.Is6() {
|
||||
continue
|
||||
}
|
||||
if netip.AddrFrom16(body.AAAA) == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
557
control/dns_control.go
Normal file
557
control/dns_control.go
Normal file
@ -0,0 +1,557 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/mzz2017/softwind/netproxy"
|
||||
"github.com/mzz2017/softwind/pool"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/common/netutils"
|
||||
"github.com/v2rayA/dae/component/dns"
|
||||
"github.com/v2rayA/dae/component/outbound"
|
||||
"github.com/v2rayA/dae/component/outbound/dialer"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxDnsLookupDepth = 3
|
||||
)
|
||||
|
||||
var (
|
||||
SuspectedRushAnswerError = fmt.Errorf("suspected DNS rush-answer")
|
||||
UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type")
|
||||
)
|
||||
|
||||
type DnsControllerOption struct {
|
||||
Log *logrus.Logger
|
||||
CacheAccessCallback func(cache *DnsCache) (err error)
|
||||
NewCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
|
||||
BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
|
||||
}
|
||||
|
||||
type DnsController struct {
|
||||
routing *dns.Dns
|
||||
|
||||
log *logrus.Logger
|
||||
cacheAccessCallback func(cache *DnsCache) (err error)
|
||||
newCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
|
||||
bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
|
||||
|
||||
// mutex protects the dnsCache.
|
||||
dnsCacheMu sync.Mutex
|
||||
dnsCache map[string]*DnsCache
|
||||
}
|
||||
|
||||
func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsController, err error) {
|
||||
return &DnsController{
|
||||
routing: routing,
|
||||
|
||||
log: option.Log,
|
||||
cacheAccessCallback: option.CacheAccessCallback,
|
||||
newCache: option.NewCache,
|
||||
bestDialerChooser: option.BestDialerChooser,
|
||||
|
||||
dnsCacheMu: sync.Mutex{},
|
||||
dnsCache: make(map[string]*DnsCache),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *DnsController) LookupDnsRespCache(domain string, t dnsmessage.Type) (cache *DnsCache) {
|
||||
now := time.Now()
|
||||
|
||||
// To fqdn.
|
||||
if !strings.HasSuffix(domain, ".") {
|
||||
domain = domain + "."
|
||||
}
|
||||
c.dnsCacheMu.Lock()
|
||||
cache, ok := c.dnsCache[strings.ToLower(domain)+t.String()]
|
||||
c.dnsCacheMu.Unlock()
|
||||
if ok && cache.Deadline.After(now) {
|
||||
return cache
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookupDnsRespCache_ will modify the msg in place.
|
||||
func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) {
|
||||
if len(msg.Questions) == 0 {
|
||||
return nil
|
||||
}
|
||||
q := msg.Questions[0]
|
||||
if msg.Response {
|
||||
return nil
|
||||
}
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
cache := c.LookupDnsRespCache(q.Name.String(), q.Type)
|
||||
if cache != nil {
|
||||
cache.FillInto(msg)
|
||||
b, err := msg.Pack()
|
||||
if err != nil {
|
||||
c.log.Warnf("failed to pack: %v", err)
|
||||
return nil
|
||||
}
|
||||
if err = c.cacheAccessCallback(cache); err != nil {
|
||||
c.log.Warnf("failed to BatchUpdateDomainRouting: %v", err)
|
||||
return nil
|
||||
}
|
||||
return b
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DnsRespHandler handle DNS resp.
|
||||
func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMsg *dnsmessage.Message, err error) {
|
||||
var msg dnsmessage.Message
|
||||
if err = msg.Unpack(data); err != nil {
|
||||
return nil, fmt.Errorf("unpack dns pkt: %w", err)
|
||||
}
|
||||
// Check healthy resp.
|
||||
if !msg.Response || len(msg.Questions) == 0 {
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
q := msg.Questions[0]
|
||||
|
||||
// Check suc resp.
|
||||
if msg.RCode != dnsmessage.RCodeSuccess {
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// Check req type.
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
default:
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
// Set ttl.
|
||||
var ttl uint32
|
||||
for i := range msg.Answers {
|
||||
if ttl == 0 {
|
||||
ttl = msg.Answers[i].Header.TTL
|
||||
}
|
||||
// Set TTL = zero. This requests applications must resend every request.
|
||||
// However, it may be not defined in the standard.
|
||||
msg.Answers[i].Header.TTL = 0
|
||||
}
|
||||
|
||||
// Check if there is any A/AAAA record.
|
||||
var hasIpRecord bool
|
||||
loop:
|
||||
for i := range msg.Answers {
|
||||
switch msg.Answers[i].Header.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
hasIpRecord = true
|
||||
break loop
|
||||
}
|
||||
}
|
||||
if !hasIpRecord {
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
if validateRushAns {
|
||||
exist, e := EnsureAdditionalOpt(&msg, false)
|
||||
if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) {
|
||||
c.log.Warnf("EnsureAdditionalOpt: %v", e)
|
||||
}
|
||||
if e == nil && !exist {
|
||||
// Additional record OPT in the request was ensured, and in normal case the resp should also set it.
|
||||
// This DNS packet may be a rush-answer, and we should reject it.
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"ques": q,
|
||||
"addition": FormatDnsRsc(msg.Additionals),
|
||||
"ans": FormatDnsRsc(msg.Answers),
|
||||
}).Traceln("DNS rush-answer detected")
|
||||
return nil, SuspectedRushAnswerError
|
||||
}
|
||||
}
|
||||
|
||||
// Update DnsCache.
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"qname": q.Name,
|
||||
"rcode": msg.RCode,
|
||||
"ans": FormatDnsRsc(msg.Answers),
|
||||
"auth": FormatDnsRsc(msg.Authorities),
|
||||
"addition": FormatDnsRsc(msg.Additionals),
|
||||
}).Tracef("Update DNS record cache")
|
||||
}
|
||||
if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Pack to get newData.
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func (c *DnsController) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) {
|
||||
var fqdn string
|
||||
if strings.HasSuffix(host, ".") {
|
||||
fqdn = host
|
||||
host = host[:len(host)-1]
|
||||
} else {
|
||||
fqdn = host + "."
|
||||
}
|
||||
// Bypass pure IP.
|
||||
if _, err = netip.ParseAddr(host); err == nil {
|
||||
return nil
|
||||
}
|
||||
cacheKey := fqdn + typ.String()
|
||||
c.dnsCacheMu.Lock()
|
||||
cache, ok := c.dnsCache[cacheKey]
|
||||
if ok {
|
||||
c.dnsCacheMu.Unlock()
|
||||
cache.Deadline = deadline
|
||||
cache.Answers = answers
|
||||
} else {
|
||||
cache, err = c.newCache(fqdn, answers, deadline)
|
||||
if err != nil {
|
||||
c.dnsCacheMu.Unlock()
|
||||
return err
|
||||
}
|
||||
c.dnsCache[cacheKey] = cache
|
||||
c.dnsCacheMu.Unlock()
|
||||
}
|
||||
if err = c.cacheAccessCallback(cache); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DnsController) DnsRespHandlerFactory(req *udpRequest, validateRushAnsFunc func(from netip.AddrPort) bool) func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
|
||||
return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
|
||||
// Do not return conn-unrelated err in this func.
|
||||
|
||||
validateRushAns := validateRushAnsFunc(from)
|
||||
msg, err = c.DnsRespHandler(data, validateRushAns)
|
||||
if err != nil {
|
||||
if errors.Is(err, SuspectedRushAnswerError) {
|
||||
if validateRushAns {
|
||||
// Reject DNS rush-answer.
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"from": from,
|
||||
}).Tracef("DNS rush-answer rejected")
|
||||
return nil, nil
|
||||
}
|
||||
} else {
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.Debugf("DnsRespHandler: %v", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
type udpRequest struct {
|
||||
lanWanFlag consts.LanWanFlag
|
||||
realSrc netip.AddrPort
|
||||
realDst netip.AddrPort
|
||||
src netip.AddrPort
|
||||
lConn *net.UDPConn
|
||||
routingResult *bpfRoutingResult
|
||||
}
|
||||
|
||||
type dialArgument struct {
|
||||
l4proto consts.L4ProtoStr
|
||||
ipversion consts.IpVersionStr
|
||||
bestDialer *dialer.Dialer
|
||||
bestOutbound *outbound.DialerGroup
|
||||
bestTarget netip.AddrPort
|
||||
mark uint32
|
||||
}
|
||||
|
||||
func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) {
|
||||
if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil {
|
||||
// Send cache to client directly.
|
||||
if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil {
|
||||
return fmt.Errorf("failed to write cached DNS resp: %w", err)
|
||||
}
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
|
||||
q := dnsMessage.Questions[0]
|
||||
c.log.Tracef("UDP(DNS) %v <-> Cache: %v %v",
|
||||
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
|
||||
// Because rush-answer has no resp OPT. We can distinguish them from multiple responses.
|
||||
// Note that additional record OPT may not be supported by home router either.
|
||||
_, _ = EnsureAdditionalOpt(dnsMessage, true)
|
||||
|
||||
// Route request.
|
||||
upstream, err := c.routing.RequestSelect(dnsMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
upstreamName := "asis"
|
||||
if upstream != nil {
|
||||
upstreamName = upstream.String()
|
||||
}
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"question": dnsMessage.Questions,
|
||||
"upstream": upstreamName,
|
||||
}).Traceln("Request to DNS upstream")
|
||||
}
|
||||
|
||||
// Re-pack DNS packet.
|
||||
data, err := dnsMessage.Pack()
|
||||
if err != nil {
|
||||
return fmt.Errorf("pack DNS packet: %w", err)
|
||||
}
|
||||
return c.dialSend(req, data, upstream, 0)
|
||||
}
|
||||
|
||||
func (c *DnsController) dialSend(req *udpRequest, data []byte, upstream *dns.Upstream, invokingDepth int) (err error) {
|
||||
if invokingDepth >= MaxDnsLookupDepth {
|
||||
return fmt.Errorf("too deep DNS lookup invoking (depth: %v); there may be infinite loop in your DNS response routing", MaxDnsLookupDepth)
|
||||
}
|
||||
|
||||
upstreamName := "asis"
|
||||
if upstream == nil {
|
||||
// As-is.
|
||||
|
||||
// As-is should not be valid in response routing, thus using connection realDest is reasonable.
|
||||
var ip46 netutils.Ip46
|
||||
if req.realDst.Addr().Is4() {
|
||||
ip46.Ip4 = req.realDst.Addr()
|
||||
} else {
|
||||
ip46.Ip6 = req.realDst.Addr()
|
||||
}
|
||||
upstream = &dns.Upstream{
|
||||
Scheme: "udp",
|
||||
Hostname: req.realDst.Addr().String(),
|
||||
Port: req.realDst.Port(),
|
||||
Ip46: &ip46,
|
||||
}
|
||||
} else {
|
||||
upstreamName = upstream.String()
|
||||
}
|
||||
|
||||
// Select best dial arguments (outbound, dialer, l4proto, ipversion, etc.)
|
||||
dialArgument, err := c.bestDialerChooser(req, upstream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
networkType := &dialer.NetworkType{
|
||||
L4Proto: dialArgument.l4proto,
|
||||
IpVersion: dialArgument.ipversion,
|
||||
IsDns: true, // UDP relies on DNS check result.
|
||||
}
|
||||
|
||||
// dnsRespHandler caches dns response and check rush answers.
|
||||
dnsRespHandler := c.DnsRespHandlerFactory(req, func(from netip.AddrPort) bool {
|
||||
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
|
||||
// Because additional record OPT may not be supported by home router.
|
||||
// So se should trust home devices even if they make rush-answer (or looks like).
|
||||
return dialArgument.bestDialer.Name() == "direct" && !from.Addr().IsPrivate()
|
||||
})
|
||||
// Dial and send.
|
||||
var respMsg *dnsmessage.Message
|
||||
// defer in a recursive call will delay Close(), thus we Close() before
|
||||
// the next recursive call. However, a connection cannot be closed twice.
|
||||
// We should set a connClosed flag to avoid it.
|
||||
var connClosed bool
|
||||
var conn netproxy.Conn
|
||||
// TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr).
|
||||
// Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP.
|
||||
// However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine.
|
||||
switch dialArgument.l4proto {
|
||||
case consts.L4ProtoStr_UDP:
|
||||
// Get udp endpoint.
|
||||
|
||||
// TODO: connection pool.
|
||||
conn, err = dialArgument.bestDialer.Dial(
|
||||
MagicNetwork("udp", dialArgument.mark),
|
||||
dialArgument.bestTarget.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial '%v': %w", dialArgument.bestTarget, err)
|
||||
}
|
||||
defer func() {
|
||||
if !connClosed {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(DnsNatTimeout))
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"to": dialArgument.bestTarget.String(),
|
||||
"pid": req.routingResult.Pid,
|
||||
"pname": ProcessName2String(req.routingResult.Pname[:]),
|
||||
"mac": Mac2String(req.routingResult.Mac[:]),
|
||||
"from": req.realSrc.String(),
|
||||
"network": networkType.String(),
|
||||
"err": err.Error(),
|
||||
}).Debugln("Failed to write UDP(DNS) packet request.")
|
||||
}
|
||||
return fmt.Errorf("failed to write UDP(DNS) packet request: %w", err)
|
||||
}
|
||||
|
||||
// We can block here because we are in a coroutine.
|
||||
respBuf := pool.Get(512)
|
||||
defer pool.Put(respBuf)
|
||||
for {
|
||||
// Wait for response.
|
||||
n, err := conn.Read(respBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Name(), err)
|
||||
}
|
||||
respMsg, err = dnsRespHandler(respBuf[:n], dialArgument.bestTarget)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if respMsg != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
case consts.L4ProtoStr_TCP:
|
||||
// We can block here because we are in a coroutine.
|
||||
|
||||
conn, err = dialArgument.bestDialer.Dial(MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if !connClosed {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(DnsNatTimeout))
|
||||
// We should write two byte length in the front of TCP DNS request.
|
||||
bReq := pool.Get(2 + len(data))
|
||||
defer pool.Put(bReq)
|
||||
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
|
||||
copy(bReq[2:], data)
|
||||
_, err = conn.Write(bReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write DNS req: %w", err)
|
||||
}
|
||||
|
||||
// Read two byte length.
|
||||
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
|
||||
}
|
||||
respLen := int(binary.BigEndian.Uint16(bReq))
|
||||
// Try to reuse the buf.
|
||||
var buf []byte
|
||||
if len(bReq) < respLen {
|
||||
buf = pool.Get(respLen)
|
||||
defer pool.Put(buf)
|
||||
} else {
|
||||
buf = bReq
|
||||
}
|
||||
var n int
|
||||
if n, err = io.ReadFull(conn, buf[:respLen]); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload: %w", err)
|
||||
}
|
||||
respMsg, err = dnsRespHandler(buf[:n], dialArgument.bestTarget)
|
||||
if respMsg == nil && err == nil {
|
||||
err = fmt.Errorf("bad DNS response")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write DNS resp to client: %w", err)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto)
|
||||
}
|
||||
|
||||
// Close conn before the recursive call.
|
||||
conn.Close()
|
||||
connClosed = true
|
||||
|
||||
// Route response.
|
||||
upstreamIndex, nextUpstream, err := c.routing.ResponseSelect(respMsg, upstream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch upstreamIndex {
|
||||
case consts.DnsResponseOutboundIndex_Accept:
|
||||
// Accept.
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"question": respMsg.Questions,
|
||||
"upstream": upstreamName,
|
||||
}).Traceln("Accept")
|
||||
}
|
||||
case consts.DnsResponseOutboundIndex_Reject:
|
||||
// Reject the request with empty answer.
|
||||
respMsg.Answers = nil
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"question": respMsg.Questions,
|
||||
"upstream": upstreamName,
|
||||
}).Traceln("Reject with empty answer")
|
||||
}
|
||||
default:
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"question": respMsg.Questions,
|
||||
"last_upstream": upstreamName,
|
||||
"next_upstream": nextUpstream.String(),
|
||||
}).Traceln("Change DNS upstream and resend")
|
||||
}
|
||||
return c.dialSend(req, data, nextUpstream, invokingDepth+1)
|
||||
}
|
||||
if upstreamIndex.IsReserved() && c.log.IsLevelEnabled(logrus.InfoLevel) {
|
||||
var qname, qtype string
|
||||
if len(respMsg.Questions) > 0 {
|
||||
q := respMsg.Questions[0]
|
||||
qname = strings.ToLower(q.Name.String())
|
||||
qtype = q.Type.String()
|
||||
}
|
||||
fields := logrus.Fields{
|
||||
"network": networkType.String(),
|
||||
"outbound": dialArgument.bestOutbound.Name,
|
||||
"policy": dialArgument.bestOutbound.GetSelectionPolicy(),
|
||||
"dialer": dialArgument.bestDialer.Name(),
|
||||
"qname": qname,
|
||||
"qtype": qtype,
|
||||
"pid": req.routingResult.Pid,
|
||||
"pname": ProcessName2String(req.routingResult.Pname[:]),
|
||||
"mac": Mac2String(req.routingResult.Mac[:]),
|
||||
}
|
||||
switch upstreamIndex {
|
||||
case consts.DnsResponseOutboundIndex_Accept:
|
||||
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), RefineAddrPortToShow(dialArgument.bestTarget))
|
||||
case consts.DnsResponseOutboundIndex_Reject:
|
||||
c.log.WithFields(fields).Infof("%v -> reject", RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag))
|
||||
default:
|
||||
return fmt.Errorf("unknown upstream: %v", upstreamIndex.String())
|
||||
}
|
||||
}
|
||||
data, err = respMsg.Pack()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
102
control/dns_utils.go
Normal file
102
control/dns_utils.go
Normal file
@ -0,0 +1,102 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FlipDnsQuestionCase is used to reduce dns pollution.
|
||||
func FlipDnsQuestionCase(dm *dnsmessage.Message) {
|
||||
if len(dm.Questions) == 0 {
|
||||
return
|
||||
}
|
||||
q := &dm.Questions[0]
|
||||
// For reproducibility, we use dm.ID as input and add some entropy to make the results more discrete.
|
||||
h := fnv.New64()
|
||||
var buf [4]byte
|
||||
binary.BigEndian.PutUint16(buf[:], dm.ID)
|
||||
h.Write(buf[:2])
|
||||
binary.BigEndian.PutUint32(buf[:], 20230204) // entropy
|
||||
h.Write(buf[:])
|
||||
r := rand.New(rand.NewSource(int64(h.Sum64())))
|
||||
perm := r.Perm(int(q.Name.Length))
|
||||
for i := 0; i < int(q.Name.Length/3); i++ {
|
||||
j := perm[i]
|
||||
// Upper to lower; lower to upper.
|
||||
if q.Name.Data[j] >= 'a' && q.Name.Data[j] <= 'z' {
|
||||
q.Name.Data[j] -= 'a' - 'A'
|
||||
} else if q.Name.Data[j] >= 'A' && q.Name.Data[j] <= 'Z' {
|
||||
q.Name.Data[j] += 'a' - 'A'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureAdditionalOpt makes sure there is additional record OPT in the request.
|
||||
func EnsureAdditionalOpt(dm *dnsmessage.Message, isReqAdd bool) (bool, error) {
|
||||
// Check healthy resp.
|
||||
if isReqAdd == dm.Response || dm.RCode != dnsmessage.RCodeSuccess || len(dm.Questions) == 0 {
|
||||
return false, UnsupportedQuestionTypeError
|
||||
}
|
||||
q := dm.Questions[0]
|
||||
switch q.Type {
|
||||
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
|
||||
default:
|
||||
return false, UnsupportedQuestionTypeError
|
||||
}
|
||||
|
||||
for _, ad := range dm.Additionals {
|
||||
if ad.Header.Type == dnsmessage.TypeOPT {
|
||||
// Already has additional record OPT.
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
if !isReqAdd {
|
||||
return false, nil
|
||||
}
|
||||
// Add one.
|
||||
dm.Additionals = append(dm.Additionals, dnsmessage.Resource{
|
||||
Header: dnsmessage.ResourceHeader{
|
||||
Name: dnsmessage.MustNewName("."),
|
||||
Type: dnsmessage.TypeOPT,
|
||||
Class: 512, TTL: 0, Length: 0,
|
||||
},
|
||||
Body: &dnsmessage.OPTResource{
|
||||
Options: nil,
|
||||
},
|
||||
})
|
||||
return false, nil
|
||||
}
|
||||
|
||||
type RscWrapper struct {
|
||||
Rsc dnsmessage.Resource
|
||||
}
|
||||
|
||||
func (w RscWrapper) String() string {
|
||||
var strBody string
|
||||
switch body := w.Rsc.Body.(type) {
|
||||
case *dnsmessage.AResource:
|
||||
strBody = netip.AddrFrom4(body.A).String()
|
||||
case *dnsmessage.AAAAResource:
|
||||
strBody = netip.AddrFrom16(body.AAAA).String()
|
||||
default:
|
||||
strBody = body.GoString()
|
||||
}
|
||||
return fmt.Sprintf("%v(%v): %v", w.Rsc.Header.Name.String(), w.Rsc.Header.Type.String(), strBody)
|
||||
}
|
||||
func FormatDnsRsc(ans []dnsmessage.Resource) string {
|
||||
var w []string
|
||||
for _, a := range ans {
|
||||
w = append(w, RscWrapper{Rsc: a}.String())
|
||||
}
|
||||
return strings.Join(w, "; ")
|
||||
}
|
@ -59,7 +59,7 @@
|
||||
#define OUTBOUND_DIRECT 0
|
||||
#define OUTBOUND_BLOCK 1
|
||||
#define OUTBOUND_MUST_DIRECT 0xFC
|
||||
#define OUTBOUND_CONTROL_PLANE_DIRECT 0xFD
|
||||
#define OUTBOUND_CONTROL_PLANE_ROUTING 0xFD
|
||||
#define OUTBOUND_LOGICAL_OR 0xFE
|
||||
#define OUTBOUND_LOGICAL_AND 0xFF
|
||||
#define OUTBOUND_LOGICAL_MASK 0xFE
|
||||
@ -89,6 +89,7 @@ static const __u32 disable_l4_rx_checksum_key
|
||||
__attribute__((unused, deprecated)) = 3;
|
||||
static const __u32 control_plane_pid_key = 4;
|
||||
static const __u32 control_plane_nat_direct_key = 5;
|
||||
static const __u32 control_plane_dns_routing_key = 6;
|
||||
|
||||
// Outbound Connectivity Map:
|
||||
|
||||
@ -225,23 +226,6 @@ struct {
|
||||
__uint(pinning, LIBBPF_PIN_BY_NAME);
|
||||
} ipproto_hdrsize_map SEC(".maps");
|
||||
|
||||
// Dns upstream:
|
||||
|
||||
struct dns_upstream {
|
||||
__be32 ip4[4];
|
||||
__be32 ip6[4];
|
||||
bool hasIp4;
|
||||
bool hasIp6;
|
||||
__be16 port;
|
||||
};
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_ARRAY);
|
||||
__type(key, __u32);
|
||||
__type(value, struct dns_upstream);
|
||||
/// FIXME: l4proto is always udp.
|
||||
__uint(max_entries, 1);
|
||||
} dns_upstream_map SEC(".maps");
|
||||
|
||||
// Interface Ips:
|
||||
struct if_params {
|
||||
bool rx_cksm_offload;
|
||||
@ -946,7 +930,7 @@ decap_after_udp_hdr(struct __sk_buff *skb, __u8 ipversion, __u8 ihl,
|
||||
// low -> high: outbound(8b) mark(32b) unused(23b) sign(1b)
|
||||
static __s64 __attribute__((noinline))
|
||||
routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
|
||||
const __be32 _daddr[4], const __be32 mac[4]) {
|
||||
const __be32 daddr[4], const __be32 mac[4]) {
|
||||
#define _l4proto_type flag[0]
|
||||
#define _ipversion_type flag[1]
|
||||
#define _pname &flag[2]
|
||||
@ -957,7 +941,6 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
|
||||
__u32 key = MatchType_L4Proto;
|
||||
__u16 h_dport;
|
||||
__u16 h_sport;
|
||||
__u32 daddr[4];
|
||||
|
||||
/// TODO: BPF_MAP_UPDATE_BATCH ?
|
||||
if (unlikely((ret = bpf_map_update_elem(&l4proto_ipversion_map, &key,
|
||||
@ -992,27 +975,11 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
|
||||
|
||||
// Modify DNS upstream for routing.
|
||||
if (h_dport == 53 && _l4proto_type == L4ProtoType_UDP) {
|
||||
struct dns_upstream *upstream =
|
||||
bpf_map_lookup_elem(&dns_upstream_map, &zero_key);
|
||||
if (upstream && upstream->port != 0) {
|
||||
h_dport = bpf_ntohs(upstream->port);
|
||||
if (_ipversion_type == IpVersionType_4 && upstream->hasIp4) {
|
||||
__builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH);
|
||||
} else if (_ipversion_type == IpVersionType_6 && upstream->hasIp6) {
|
||||
__builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH);
|
||||
} else if (upstream->hasIp4) {
|
||||
__builtin_memcpy(daddr, upstream->ip4, IPV6_BYTE_LENGTH);
|
||||
} else if (upstream->hasIp6) {
|
||||
__builtin_memcpy(daddr, upstream->ip6, IPV6_BYTE_LENGTH);
|
||||
} else {
|
||||
bpf_printk("bad dns upstream; use as-is.");
|
||||
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
|
||||
}
|
||||
} else {
|
||||
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
|
||||
__u32 *control_plane_dns_routing =
|
||||
bpf_map_lookup_elem(¶m_map, &control_plane_dns_routing_key);
|
||||
if (control_plane_dns_routing && *control_plane_dns_routing) {
|
||||
return OUTBOUND_CONTROL_PLANE_ROUTING;
|
||||
}
|
||||
} else {
|
||||
__builtin_memcpy(daddr, _daddr, IPV6_BYTE_LENGTH);
|
||||
}
|
||||
lpm_key_instance.trie_key.prefixlen = IPV6_BYTE_LENGTH * 8;
|
||||
__builtin_memcpy(lpm_key_instance.data, daddr, IPV6_BYTE_LENGTH);
|
||||
@ -1169,11 +1136,6 @@ routing(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
|
||||
bpf_printk("MATCHED: match_set->type: %u, match_set->not: %d",
|
||||
match_set->type, match_set->not );
|
||||
#endif
|
||||
if (match_set->outbound == OUTBOUND_DIRECT && h_dport == 53 &&
|
||||
_l4proto_type == L4ProtoType_UDP) {
|
||||
// DNS packet should go through control plane.
|
||||
return OUTBOUND_CONTROL_PLANE_DIRECT | (match_set->mark << 8);
|
||||
}
|
||||
return match_set->outbound | (match_set->mark << 8);
|
||||
}
|
||||
bad_rule = false;
|
||||
@ -1583,7 +1545,6 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
|
||||
bpf_skc_lookup_tcp(skb, &tuple, tuple_size, BPF_F_CURRENT_NETNS, 0);
|
||||
if (sk) {
|
||||
// Not a tproxy WAN response. It is a tproxy LAN response.
|
||||
tproxy_response = false;
|
||||
bpf_sk_release(sk);
|
||||
return TC_ACT_OK;
|
||||
}
|
||||
@ -1594,6 +1555,9 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
|
||||
// Packets from tproxy port.
|
||||
// We need to redirect it to original port.
|
||||
|
||||
// bpf_printk("tproxy_response: %pI6:%u", tuples.dip.u6_addr32,
|
||||
// bpf_ntohs(tuples.dport));
|
||||
|
||||
// Write mac.
|
||||
if ((ret = bpf_skb_store_bytes(skb, offsetof(struct ethhdr, h_dest),
|
||||
ethh.h_source, sizeof(ethh.h_source), 0))) {
|
||||
@ -1665,7 +1629,7 @@ int tproxy_wan_egress(struct __sk_buff *skb) {
|
||||
struct dst_routing_result *dst =
|
||||
bpf_map_lookup_elem(&tcp_dst_map, &key_src);
|
||||
if (!dst) {
|
||||
// Do not impact previous connections.
|
||||
// Do not impact previous connections and server connections.
|
||||
return TC_ACT_OK;
|
||||
}
|
||||
outbound = dst->routing_result.outbound;
|
||||
@ -1978,7 +1942,7 @@ int tproxy_wan_ingress(struct __sk_buff *skb) {
|
||||
return TC_ACT_SHOT;
|
||||
}
|
||||
|
||||
// bpf_printk("real from: %pI4:%u", &ori_src.ip, bpf_ntohs(ori_src.port));
|
||||
// bpf_printk("real from: %pI6:%u", ori_src.ip, bpf_ntohs(ori_src.port));
|
||||
|
||||
// Print packet in hex for debugging (checksum or something else).
|
||||
// bpf_printk("UDP EGRESS OK");
|
||||
|
@ -9,79 +9,96 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/cilium/ebpf"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"github.com/v2rayA/dae/component/routing/domain_matcher"
|
||||
"github.com/v2rayA/dae/config"
|
||||
"github.com/v2rayA/dae/pkg/config_parser"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type RoutingMatcherBuilder struct {
|
||||
*routing.DefaultMatcherBuilder
|
||||
outboundName2Id map[string]uint8
|
||||
bpf *bpfObjects
|
||||
rules []bpfMatchSet
|
||||
simulatedLpmTries [][]netip.Prefix
|
||||
simulatedDomainSet []routing.DomainSet
|
||||
|
||||
err error
|
||||
fallback *routing.Outbound
|
||||
}
|
||||
|
||||
func NewRoutingMatcherBuilder(outboundName2Id map[string]uint8, bpf *bpfObjects) *RoutingMatcherBuilder {
|
||||
return &RoutingMatcherBuilder{outboundName2Id: outboundName2Id, bpf: bpf}
|
||||
func NewRoutingMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, outboundName2Id map[string]uint8, bpf *bpfObjects, fallback config.FunctionOrString) (b *RoutingMatcherBuilder, err error) {
|
||||
b = &RoutingMatcherBuilder{outboundName2Id: outboundName2Id, bpf: bpf}
|
||||
rulesBuilder := routing.NewRulesBuilder(log)
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Domain, routing.PlainParserFactory(b.addDomain))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_SourceIp, routing.IpParserFactory(b.addSourceIp))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Port, routing.PortRangeParserFactory(b.addPort))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_SourcePort, routing.PortRangeParserFactory(b.addSourcePort))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_L4Proto, routing.L4ProtoParserFactory(b.addL4Proto))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_Mac, routing.MacParserFactory(b.addSourceMac))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_ProcessName, routing.ProcessNameParserFactory(b.addProcessName))
|
||||
rulesBuilder.RegisterFunctionParser(consts.Function_IpVersion, routing.IpVersionParserFactory(b.addIpVersion))
|
||||
if err = rulesBuilder.Apply(rules); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = b.addFallback(fallback); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 {
|
||||
func (b *RoutingMatcherBuilder) outboundToId(outbound string) (uint8, error) {
|
||||
var outboundId uint8
|
||||
switch outbound {
|
||||
case routing.FakeOutbound_MUST_DIRECT:
|
||||
case consts.OutboundMustDirect.String():
|
||||
outboundId = uint8(consts.OutboundMustDirect)
|
||||
case routing.FakeOutbound_AND:
|
||||
outboundId = uint8(consts.OutboundLogicalAnd)
|
||||
case routing.FakeOutbound_OR:
|
||||
case consts.OutboundLogicalOr.String():
|
||||
outboundId = uint8(consts.OutboundLogicalOr)
|
||||
case consts.OutboundLogicalAnd.String():
|
||||
outboundId = uint8(consts.OutboundLogicalAnd)
|
||||
default:
|
||||
var ok bool
|
||||
outboundId, ok = b.outboundName2Id[outbound]
|
||||
if !ok {
|
||||
b.err = fmt.Errorf("outbound (group) %v not found; please define it in section \"group\"", strconv.Quote(outbound))
|
||||
return 0, fmt.Errorf("outbound (group) %v not found; please define it in section \"group\"", strconv.Quote(outbound))
|
||||
}
|
||||
}
|
||||
return outboundId
|
||||
return outboundId, nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddDomain(f *config_parser.Function, key string, values []string, outbound *routing.Outbound) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
func (b *RoutingMatcherBuilder) addDomain(f *config_parser.Function, key string, values []string, outbound *routing.Outbound) (err error) {
|
||||
switch consts.RoutingDomainKey(key) {
|
||||
case consts.RoutingDomainKey_Regex,
|
||||
consts.RoutingDomainKey_Full,
|
||||
consts.RoutingDomainKey_Keyword,
|
||||
consts.RoutingDomainKey_Suffix:
|
||||
default:
|
||||
b.err = fmt.Errorf("AddDomain: unsupported key: %v", key)
|
||||
return
|
||||
return fmt.Errorf("addDomain: unsupported key: %v", key)
|
||||
}
|
||||
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
||||
Key: consts.RoutingDomainKey(key),
|
||||
RuleIndex: len(b.rules),
|
||||
Domains: values,
|
||||
})
|
||||
outboundId, err := b.outboundToId(outbound.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, bpfMatchSet{
|
||||
Type: uint8(consts.MatchType_DomainSet),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outbound.Name),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound *routing.Outbound) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
func (b *RoutingMatcherBuilder) addSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound *routing.Outbound) (err error) {
|
||||
var addr16 [16]byte
|
||||
values := make([]netip.Prefix, 0, len(macAddrs))
|
||||
for _, mac := range macAddrs {
|
||||
@ -91,41 +108,51 @@ func (b *RoutingMatcherBuilder) AddSourceMac(f *config_parser.Function, macAddrs
|
||||
}
|
||||
lpmTrieIndex := len(b.simulatedLpmTries)
|
||||
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
|
||||
outboundId, err := b.outboundToId(outbound.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set := bpfMatchSet{
|
||||
Value: [16]byte{},
|
||||
Type: uint8(consts.MatchType_Mac),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outbound.Name),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
|
||||
b.rules = append(b.rules, set)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
func (b *RoutingMatcherBuilder) addIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) (err error) {
|
||||
lpmTrieIndex := len(b.simulatedLpmTries)
|
||||
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
|
||||
outboundId, err := b.outboundToId(outbound.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set := bpfMatchSet{
|
||||
Value: [16]byte{},
|
||||
Type: uint8(consts.MatchType_IpSet),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outbound.Name),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
|
||||
b.rules = append(b.rules, set)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) {
|
||||
func (b *RoutingMatcherBuilder) addPort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
outboundName := routing.FakeOutbound_OR
|
||||
outboundName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
outboundName = outbound.Name
|
||||
}
|
||||
outboundId, err := b.outboundToId(outboundName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, bpfMatchSet{
|
||||
Type: uint8(consts.MatchType_Port),
|
||||
Value: _bpfPortRange{
|
||||
@ -133,35 +160,42 @@ func (b *RoutingMatcherBuilder) AddPort(f *config_parser.Function, values [][2]u
|
||||
PortEnd: value[1],
|
||||
}.Encode(),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outboundName),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) {
|
||||
if b.err != nil {
|
||||
return
|
||||
}
|
||||
func (b *RoutingMatcherBuilder) addSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) (err error) {
|
||||
lpmTrieIndex := len(b.simulatedLpmTries)
|
||||
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
|
||||
outboundId, err := b.outboundToId(outbound.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
set := bpfMatchSet{
|
||||
Value: [16]byte{},
|
||||
Type: uint8(consts.MatchType_SourceIpSet),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outbound.Name),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
|
||||
b.rules = append(b.rules, set)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) {
|
||||
func (b *RoutingMatcherBuilder) addSourcePort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
outboundName := routing.FakeOutbound_OR
|
||||
outboundName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
outboundName = outbound.Name
|
||||
}
|
||||
outboundId, err := b.outboundToId(outboundName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, bpfMatchSet{
|
||||
Type: uint8(consts.MatchType_SourcePort),
|
||||
Value: _bpfPortRange{
|
||||
@ -169,70 +203,83 @@ func (b *RoutingMatcherBuilder) AddSourcePort(f *config_parser.Function, values
|
||||
PortEnd: value[1],
|
||||
}.Encode(),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outboundName),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *routing.Outbound) {
|
||||
if b.err != nil {
|
||||
return
|
||||
func (b *RoutingMatcherBuilder) addL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *routing.Outbound) (err error) {
|
||||
outboundId, err := b.outboundToId(outbound.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, bpfMatchSet{
|
||||
Value: [16]byte{byte(values)},
|
||||
Type: uint8(consts.MatchType_L4Proto),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outbound.Name),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *routing.Outbound) {
|
||||
if b.err != nil {
|
||||
return
|
||||
func (b *RoutingMatcherBuilder) addIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *routing.Outbound) (err error) {
|
||||
outboundId, err := b.outboundToId(outbound.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, bpfMatchSet{
|
||||
Value: [16]byte{byte(values)},
|
||||
Type: uint8(consts.MatchType_IpVersion),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outbound.Name),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *routing.Outbound) {
|
||||
func (b *RoutingMatcherBuilder) addProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
outboundName := routing.FakeOutbound_OR
|
||||
outboundName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
outboundName = outbound.Name
|
||||
}
|
||||
outboundId, err := b.outboundToId(outboundName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
matchSet := bpfMatchSet{
|
||||
Type: uint8(consts.MatchType_ProcessName),
|
||||
Not: f.Not,
|
||||
Outbound: b.OutboundToId(outboundName),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
}
|
||||
copy(matchSet.Value[:], value[:])
|
||||
b.rules = append(b.rules, matchSet)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) AddFallback(outbound *routing.Outbound) {
|
||||
if b.err != nil {
|
||||
return
|
||||
func (b *RoutingMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
|
||||
outbound, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
outboundId, err := b.outboundToId(outbound.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.rules = append(b.rules, bpfMatchSet{
|
||||
Type: uint8(consts.MatchType_Fallback),
|
||||
Outbound: b.OutboundToId(outbound.Name),
|
||||
Outbound: outboundId,
|
||||
Mark: outbound.Mark,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) BuildKernspace() (err error) {
|
||||
if b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
// Update lpm_array_map.
|
||||
for i, cidrs := range b.simulatedLpmTries {
|
||||
var keys []_bpfLpmKey
|
||||
@ -255,8 +302,7 @@ func (b *RoutingMatcherBuilder) BuildKernspace() (err error) {
|
||||
// Write routings.
|
||||
// Fallback rule MUST be the last.
|
||||
if b.rules[len(b.rules)-1].Type != uint8(consts.MatchType_Fallback) {
|
||||
b.err = fmt.Errorf("fallback rule MUST be the last")
|
||||
return b.err
|
||||
return fmt.Errorf("fallback rule MUST be the last")
|
||||
}
|
||||
routingsLen := uint32(len(b.rules))
|
||||
routingsKeys := common.ARangeU32(routingsLen)
|
||||
@ -266,34 +312,28 @@ func (b *RoutingMatcherBuilder) BuildKernspace() (err error) {
|
||||
return fmt.Errorf("BpfMapBatchUpdate: %w", err)
|
||||
}
|
||||
|
||||
// Release.
|
||||
b.simulatedLpmTries = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RoutingMatcherBuilder) BuildUserspace() (matcher *RoutingMatcher, err error) {
|
||||
if b.err != nil {
|
||||
return nil, b.err
|
||||
}
|
||||
var m RoutingMatcher
|
||||
func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *RoutingMatcher, err error) {
|
||||
// Build domainMatcher
|
||||
m.domainMatcher = domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
|
||||
domainMatcher := domain_matcher.NewAhocorasickSlimtrie(consts.MaxMatchSetLen)
|
||||
for _, domains := range b.simulatedDomainSet {
|
||||
m.domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
|
||||
domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
|
||||
}
|
||||
if err = m.domainMatcher.Build(); err != nil {
|
||||
if err = domainMatcher.Build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write routings.
|
||||
// Fallback rule MUST be the last.
|
||||
if b.rules[len(b.rules)-1].Type != uint8(consts.MatchType_Fallback) {
|
||||
b.err = fmt.Errorf("fallback rule MUST be the last")
|
||||
return nil, b.err
|
||||
return nil, fmt.Errorf("fallback rule MUST be the last")
|
||||
}
|
||||
m.matches = b.rules
|
||||
|
||||
// Release.
|
||||
b.simulatedDomainSet = nil
|
||||
return &m, nil
|
||||
return &RoutingMatcher{
|
||||
lpmArrayMap: lpmArrayMap,
|
||||
domainMatcher: domainMatcher,
|
||||
matches: b.rules,
|
||||
}, nil
|
||||
}
|
||||
|
@ -8,13 +8,11 @@ package control
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/Asphaltt/lpmtrie"
|
||||
"github.com/cilium/ebpf"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/routing"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type RoutingMatcher struct {
|
||||
@ -33,11 +31,11 @@ func (m *RoutingMatcher) Match(
|
||||
ipVersion consts.IpVersionType,
|
||||
l4proto consts.L4ProtoType,
|
||||
domain string,
|
||||
processName string,
|
||||
processName [16]uint8,
|
||||
mac []byte,
|
||||
) (outboundIndex consts.OutboundIndex, err error) {
|
||||
) (outboundIndex consts.OutboundIndex, mark uint32, err error) {
|
||||
if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len {
|
||||
return 0, fmt.Errorf("bad address length")
|
||||
return 0, 0, fmt.Errorf("bad address length")
|
||||
}
|
||||
lpmKeys := make([]*_bpfLpmKey, consts.MatchType_Mac+1)
|
||||
lpmKeys[consts.MatchType_IpSet] = &_bpfLpmKey{
|
||||
@ -68,11 +66,13 @@ func (m *RoutingMatcher) Match(
|
||||
lpmIndex := uint32(binary.LittleEndian.Uint16(match.Value[:]))
|
||||
var lpm *ebpf.Map
|
||||
if err = m.lpmArrayMap.Lookup(lpmIndex, &lpm); err != nil {
|
||||
//logrus.Debugln("m.lpmArrayMap.Lookup:", err)
|
||||
break
|
||||
}
|
||||
var v uint32
|
||||
if err = lpm.Lookup(*lpmKeys[int(match.Type)], &v); err != nil {
|
||||
_ = lpm.Close()
|
||||
//logrus.Debugln("lpm.Lookup:", err, lpmKeys[int(match.Type)], match.Type, destAddr)
|
||||
break
|
||||
}
|
||||
_ = lpm.Close()
|
||||
@ -104,13 +104,13 @@ func (m *RoutingMatcher) Match(
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_ProcessName:
|
||||
if processName != "" && string(match.Value[:]) == processName {
|
||||
if processName[0] != 0 && match.Value == processName {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Fallback:
|
||||
goodSubrule = true
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown match type: %v", match.Type)
|
||||
return 0, 0, fmt.Errorf("unknown match type: %v", match.Type)
|
||||
}
|
||||
beforeNextLoop:
|
||||
outbound := consts.OutboundIndex(match.Outbound)
|
||||
@ -133,27 +133,10 @@ func (m *RoutingMatcher) Match(
|
||||
// Tail of a rule (line).
|
||||
// Decide whether to hit.
|
||||
if !badRule {
|
||||
if outbound == consts.OutboundDirect && destPort == 53 &&
|
||||
l4proto == consts.L4ProtoType_UDP {
|
||||
// DNS packet should go through control plane.
|
||||
return consts.OutboundControlPlaneDirect, nil
|
||||
}
|
||||
return outbound, nil
|
||||
return outbound, match.Mark, nil
|
||||
}
|
||||
badRule = false
|
||||
}
|
||||
}
|
||||
return 0, fmt.Errorf("no match set hit")
|
||||
}
|
||||
|
||||
func cidrToLpmTrieKey(prefix netip.Prefix) lpmtrie.Key {
|
||||
bits := prefix.Bits()
|
||||
if prefix.Addr().Is4() {
|
||||
bits += 96
|
||||
}
|
||||
ip := prefix.Addr().As16()
|
||||
return lpmtrie.Key{
|
||||
PrefixLen: bits,
|
||||
Data: ip[:],
|
||||
}
|
||||
return 0, 0, fmt.Errorf("no match set hit")
|
||||
}
|
||||
|
@ -59,20 +59,27 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
|
||||
}
|
||||
dst = netip.AddrPortFrom(dstAddr, common.Htons(value.Port))
|
||||
}
|
||||
src = common.ConvergeAddrPort(src)
|
||||
dst = common.ConvergeAddrPort(dst)
|
||||
|
||||
var outboundIndex = consts.OutboundIndex(routingResult.Outbound)
|
||||
|
||||
switch outboundIndex {
|
||||
case consts.OutboundDirect:
|
||||
case consts.OutboundMustDirect:
|
||||
fallthrough
|
||||
case consts.OutboundControlPlaneDirect:
|
||||
outboundIndex = consts.OutboundDirect
|
||||
case consts.OutboundControlPlaneRouting:
|
||||
if outboundIndex, routingResult.Mark, err = c.Route(src, dst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
|
||||
return err
|
||||
}
|
||||
routingResult.Outbound = uint8(outboundIndex)
|
||||
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.Tracef("outbound: %v => %v",
|
||||
consts.OutboundControlPlaneRouting.String(),
|
||||
outboundIndex.String(),
|
||||
consts.OutboundDirect.String(),
|
||||
)
|
||||
}
|
||||
outboundIndex = consts.OutboundDirect
|
||||
default:
|
||||
}
|
||||
outbound := c.outbounds[outboundIndex]
|
||||
@ -104,8 +111,7 @@ func (c *ControlPlane) handleConn(lConn net.Conn) (err error) {
|
||||
}
|
||||
|
||||
// Dial and relay.
|
||||
dst = netip.AddrPortFrom(common.ConvergeIp(dst.Addr()), dst.Port())
|
||||
rConn, err := d.Dial(GetNetwork("tcp", routingResult.Mark), c.ChooseDialTarget(outboundIndex, dst, domain))
|
||||
rConn, err := d.Dial(MagicNetwork("tcp", routingResult.Mark), c.ChooseDialTarget(outboundIndex, dst, domain))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial %v: %w", dst, err)
|
||||
}
|
||||
|
448
control/udp.go
448
control/udp.go
@ -7,21 +7,18 @@ package control
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
|
||||
"github.com/mzz2017/softwind/pool"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/v2rayA/dae/common"
|
||||
"github.com/v2rayA/dae/common/consts"
|
||||
"github.com/v2rayA/dae/component/outbound/dialer"
|
||||
"github.com/v2rayA/dae/component/sniffing"
|
||||
internal "github.com/v2rayA/dae/pkg/ebpf_internal"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
@ -33,16 +30,13 @@ const (
|
||||
MaxRetry = 2
|
||||
)
|
||||
|
||||
var (
|
||||
UnspecifiedAddr4 = netip.AddrFrom4([4]byte{})
|
||||
UnspecifiedAddr6 = netip.AddrFrom16([16]byte{})
|
||||
)
|
||||
|
||||
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Message, timeout time.Duration) {
|
||||
var dnsmsg dnsmessage.Message
|
||||
if err := dnsmsg.Unpack(data); err == nil {
|
||||
//log.Printf("DEBUG: lookup %v", dnsmsg.Questions[0].Name)
|
||||
return &dnsmsg, DnsNatTimeout
|
||||
if sniffDns {
|
||||
var dnsmsg dnsmessage.Message
|
||||
if err := dnsmsg.Unpack(data); err == nil {
|
||||
//log.Printf("DEBUG: lookup %v", dnsmsg.Questions[0].Name)
|
||||
return &dnsmsg, DnsNatTimeout
|
||||
}
|
||||
}
|
||||
return nil, DefaultNatTimeout
|
||||
}
|
||||
@ -57,29 +51,34 @@ func ParseAddrHdr(data []byte) (hdr *bpfDstRoutingResult, dataOffset int, err er
|
||||
return &_hdr, dataOffset, nil
|
||||
}
|
||||
|
||||
func sendPktWithHdrWithFlag(data []byte, mark uint32, from netip.AddrPort, lConn *net.UDPConn, to netip.AddrPort, lanWanFlag consts.LanWanFlag) error {
|
||||
func sendPktWithHdrWithFlag(data []byte, realFrom netip.AddrPort, lConn *net.UDPConn, to netip.AddrPort, lanWanFlag consts.LanWanFlag) error {
|
||||
realFrom16 := realFrom.Addr().As16()
|
||||
hdr := bpfDstRoutingResult{
|
||||
Ip: common.Ipv6ByteSliceToUint32Array(from.Addr().AsSlice()),
|
||||
Port: common.Htons(from.Port()),
|
||||
Ip: common.Ipv6ByteSliceToUint32Array(realFrom16[:]),
|
||||
Port: common.Htons(realFrom.Port()),
|
||||
RoutingResult: bpfRoutingResult{
|
||||
Outbound: uint8(lanWanFlag), // Pass some message to the kernel program.
|
||||
},
|
||||
}
|
||||
buf := pool.Get(int(unsafe.Sizeof(hdr)) + len(data))
|
||||
defer pool.Put(buf)
|
||||
b := buffer.NewBufferFrom(buf)
|
||||
// Do not put this 'buf' because it has been taken by buffer.
|
||||
b := buffer.NewBuffer(int(unsafe.Sizeof(hdr)) + len(data))
|
||||
defer b.Put()
|
||||
if err := gob.NewEncoder(b).Encode(&hdr); err != nil {
|
||||
// Use internal.NativeEndian due to already big endian.
|
||||
if err := binary.Write(b, internal.NativeEndian, hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
copy(buf[int(unsafe.Sizeof(hdr)):], data)
|
||||
//log.Println("from", from, "to", to)
|
||||
_, err := lConn.WriteToUDPAddrPort(buf, to)
|
||||
b.Write(data)
|
||||
//logrus.Debugln("sendPktWithHdrWithFlag: from", realFrom, "to", to)
|
||||
_, err := lConn.WriteToUDPAddrPort(b.Bytes(), to)
|
||||
return err
|
||||
}
|
||||
|
||||
// sendPkt uses bind first, and fallback to send hdr if addr is in use.
|
||||
func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn *net.UDPConn, lanWanFlag consts.LanWanFlag) (err error) {
|
||||
if lanWanFlag == consts.LanWanFlag_IsWan {
|
||||
return sendPktWithHdrWithFlag(data, from, lConn, to, lanWanFlag)
|
||||
}
|
||||
|
||||
d := net.Dialer{Control: func(network, address string, c syscall.RawConn) error {
|
||||
return dialer.BindControl(c, from)
|
||||
}}
|
||||
@ -88,7 +87,7 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn
|
||||
if err != nil {
|
||||
if errors.Is(err, syscall.EADDRINUSE) {
|
||||
// Port collision, use traditional method.
|
||||
return sendPktWithHdrWithFlag(data, 0, from, lConn, to, lanWanFlag)
|
||||
return sendPktWithHdrWithFlag(data, from, lConn, to, lanWanFlag)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -98,36 +97,6 @@ func sendPkt(data []byte, from netip.AddrPort, realTo, to netip.AddrPort, lConn
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ControlPlane) WriteToUDP(lanWanFlag consts.LanWanFlag, lConn *net.UDPConn, realTo, to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort, validateRushAnsFunc func(from netip.AddrPort) bool) UdpHandler {
|
||||
return func(data []byte, from netip.AddrPort) (err error) {
|
||||
// Do not return conn-unrelated err in this func.
|
||||
|
||||
if isDNS {
|
||||
validateRushAns := validateRushAnsFunc(from)
|
||||
data, err = c.DnsRespHandler(data, validateRushAns)
|
||||
if err != nil {
|
||||
if validateRushAns && errors.Is(err, SuspectedRushAnswerError) {
|
||||
// Reject DNS rush-answer.
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"from": from,
|
||||
}).Tracef("DNS rush-answer rejected")
|
||||
return err
|
||||
}
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.Debugf("DnsRespHandler: %v", err)
|
||||
}
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if dummyFrom != nil {
|
||||
from = *dummyFrom
|
||||
}
|
||||
return sendPkt(data, from, realTo, to, lConn, lanWanFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, realDst netip.AddrPort, routingResult *bpfRoutingResult) (err error) {
|
||||
var lanWanFlag consts.LanWanFlag
|
||||
var realSrc netip.AddrPort
|
||||
@ -142,60 +111,11 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
realSrc = netip.AddrPortFrom(pktDst.Addr(), src.Port())
|
||||
}
|
||||
|
||||
mustDirect := false
|
||||
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
|
||||
switch outboundIndex {
|
||||
case consts.OutboundDirect:
|
||||
case consts.OutboundMustDirect:
|
||||
mustDirect = true
|
||||
fallthrough
|
||||
case consts.OutboundControlPlaneDirect:
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.Tracef("outbound: %v => %v",
|
||||
outboundIndex.String(),
|
||||
consts.OutboundDirect.String(),
|
||||
)
|
||||
}
|
||||
outboundIndex = consts.OutboundDirect
|
||||
default:
|
||||
}
|
||||
if int(outboundIndex) >= len(c.outbounds) {
|
||||
return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
|
||||
}
|
||||
outbound := c.outbounds[outboundIndex]
|
||||
// To keep consistency with kernel program, we only sniff DNS request sent to 53.
|
||||
dnsMessage, natTimeout := ChooseNatTimeout(data, realDst.Port() == 53)
|
||||
// We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time.
|
||||
isDns := dnsMessage != nil
|
||||
var dummyFrom *netip.AddrPort
|
||||
destToSend := realDst
|
||||
if isDns {
|
||||
if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil {
|
||||
// Send cache to client directly.
|
||||
if err = sendPkt(resp, destToSend, realSrc, src, lConn, lanWanFlag); err != nil {
|
||||
return fmt.Errorf("failed to write cached DNS resp: %w", err)
|
||||
}
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
|
||||
q := dnsMessage.Questions[0]
|
||||
c.log.Tracef("UDP(DNS) %v <-[%v]-> Cache: %v %v",
|
||||
RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), outbound.Name, strings.ToLower(q.Name.String()), q.Type,
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flip dns question to reduce dns pollution.
|
||||
FlipDnsQuestionCase(dnsMessage)
|
||||
// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
|
||||
// Because rush-answer has no resp OPT. We can distinguish them from multiple responses.
|
||||
// Note that additional record OPT may not be supported by home router either.
|
||||
_, _ = EnsureAdditionalOpt(dnsMessage, true)
|
||||
|
||||
// Re-pack DNS packet.
|
||||
if data, err = dnsMessage.Pack(); err != nil {
|
||||
return fmt.Errorf("pack flipped dns packet: %w", err)
|
||||
}
|
||||
} else {
|
||||
if !isDns {
|
||||
// Sniff Quic
|
||||
sniffer := sniffing.NewPacketSniffer(data)
|
||||
domain, err = sniffer.SniffQuic()
|
||||
@ -206,247 +126,137 @@ func (c *ControlPlane) handlePkt(lConn *net.UDPConn, data []byte, src, pktDst, r
|
||||
sniffer.Close()
|
||||
}
|
||||
|
||||
l4proto := consts.L4ProtoStr_UDP
|
||||
ipversion := consts.IpVersionFromAddr(realDst.Addr())
|
||||
var dialerForNew *dialer.Dialer
|
||||
// Get outbound.
|
||||
outboundIndex := consts.OutboundIndex(routingResult.Outbound)
|
||||
switch outboundIndex {
|
||||
case consts.OutboundDirect:
|
||||
case consts.OutboundMustDirect:
|
||||
outboundIndex = consts.OutboundDirect
|
||||
isDns = false // Regard as plain traffic.
|
||||
case consts.OutboundControlPlaneRouting:
|
||||
if isDns {
|
||||
// Routing of DNS packets are managed by DNS controller.
|
||||
break
|
||||
}
|
||||
|
||||
// For DNS request, modify realDst to dns upstream.
|
||||
// NOTICE: We might modify l4proto and ipversion.
|
||||
dnsUpstream, err := c.dnsUpstream.GetUpstream()
|
||||
if err != nil {
|
||||
return err
|
||||
if outboundIndex, routingResult.Mark, err = c.Route(realSrc, realDst, domain, consts.L4ProtoType_TCP, routingResult); err != nil {
|
||||
return err
|
||||
}
|
||||
routingResult.Outbound = uint8(outboundIndex)
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.Tracef("outbound: %v => %v",
|
||||
consts.OutboundControlPlaneRouting.String(),
|
||||
outboundIndex.String(),
|
||||
)
|
||||
}
|
||||
default:
|
||||
}
|
||||
if isDns {
|
||||
return c.dnsController.Handle_(dnsMessage, &udpRequest{
|
||||
lanWanFlag: lanWanFlag,
|
||||
realSrc: realSrc,
|
||||
realDst: realDst,
|
||||
src: src,
|
||||
lConn: lConn,
|
||||
routingResult: routingResult,
|
||||
})
|
||||
}
|
||||
if isDns && dnsUpstream != nil && !mustDirect {
|
||||
// Modify dns target to upstream.
|
||||
// NOTICE: Routing was calculated in advance by the eBPF program.
|
||||
|
||||
/// Choose the best l4proto+ipversion dialer, and change taregt DNS to the best ipversion DNS upstream for DNS request.
|
||||
// Get available ipversions and l4protos for DNS upstream.
|
||||
ipversions, l4protos := dnsUpstream.SupportedNetworks()
|
||||
var (
|
||||
bestDialer *dialer.Dialer
|
||||
bestLatency time.Duration
|
||||
bestTarget netip.AddrPort
|
||||
)
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"ipversions": ipversions,
|
||||
"l4protos": l4protos,
|
||||
"src": realSrc.String(),
|
||||
}).Traceln("Choose DNS path")
|
||||
}
|
||||
// Get the min latency path.
|
||||
networkType := dialer.NetworkType{
|
||||
IsDns: isDns,
|
||||
}
|
||||
for _, ver := range ipversions {
|
||||
for _, proto := range l4protos {
|
||||
networkType.L4Proto = proto
|
||||
networkType.IpVersion = ver
|
||||
d, latency, err := outbound.Select(&networkType)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"name": d.Name(),
|
||||
"latency": latency,
|
||||
"network": networkType.String(),
|
||||
"outbound": outbound.Name,
|
||||
}).Traceln("Choice")
|
||||
}
|
||||
if bestDialer == nil || latency < bestLatency {
|
||||
bestDialer = d
|
||||
bestLatency = latency
|
||||
l4proto = proto
|
||||
ipversion = ver
|
||||
}
|
||||
}
|
||||
}
|
||||
switch ipversion {
|
||||
case consts.IpVersionStr_4:
|
||||
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip4, dnsUpstream.Port)
|
||||
case consts.IpVersionStr_6:
|
||||
bestTarget = netip.AddrPortFrom(dnsUpstream.Ip6, dnsUpstream.Port)
|
||||
}
|
||||
dialerForNew = bestDialer
|
||||
dummyFrom = &realDst
|
||||
destToSend = bestTarget
|
||||
if c.log.IsLevelEnabled(logrus.TraceLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"Original": RefineAddrPortToShow(realDst),
|
||||
"New": destToSend,
|
||||
"Network": string(l4proto) + string(ipversion),
|
||||
}).Traceln("Modify DNS target")
|
||||
}
|
||||
if int(outboundIndex) >= len(c.outbounds) {
|
||||
return fmt.Errorf("outbound %v out of range [0, %v]", outboundIndex, len(c.outbounds)-1)
|
||||
}
|
||||
outbound := c.outbounds[outboundIndex]
|
||||
|
||||
// Select dialer from outbound (dialer group).
|
||||
networkType := &dialer.NetworkType{
|
||||
L4Proto: l4proto,
|
||||
IpVersion: ipversion,
|
||||
IsDns: true,
|
||||
L4Proto: consts.L4ProtoStr_UDP,
|
||||
IpVersion: consts.IpVersionFromAddr(realDst.Addr()),
|
||||
IsDns: true, // UDP relies on DNS check result.
|
||||
}
|
||||
if dialerForNew == nil {
|
||||
dialerForNew, _, err = outbound.Select(networkType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
|
||||
}
|
||||
dialerForNew, _, err := outbound.Select(networkType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to select dialer from group %v (%v, dns?:%v,from: %v): %w", outbound.Name, networkType.StringWithoutDns(), isDns, realSrc.String(), err)
|
||||
}
|
||||
|
||||
var isNew bool
|
||||
var realDialer *dialer.Dialer
|
||||
|
||||
udpHandler := c.WriteToUDP(lanWanFlag, lConn, realSrc, src, isDns, dummyFrom, func(from netip.AddrPort) bool {
|
||||
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
|
||||
// Because additional record OPT may not be supported by home router.
|
||||
// So se should trust home devices even if they make rush-answer (or looks like).
|
||||
return outboundIndex == consts.OutboundDirect && !common.ConvergeIp(from.Addr()).IsPrivate()
|
||||
})
|
||||
// Dial and send.
|
||||
// TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr).
|
||||
// Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP.
|
||||
destToSend = netip.AddrPortFrom(common.ConvergeIp(destToSend.Addr()), destToSend.Port())
|
||||
tgtToSend := c.ChooseDialTarget(outboundIndex, destToSend, domain)
|
||||
switch l4proto {
|
||||
case consts.L4ProtoStr_UDP:
|
||||
// Get udp endpoint.
|
||||
var ue *UdpEndpoint
|
||||
retry := 0
|
||||
getNew:
|
||||
if retry > MaxRetry {
|
||||
return fmt.Errorf("touch max retry limit")
|
||||
}
|
||||
// However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine.
|
||||
dialTarget := c.ChooseDialTarget(outboundIndex, realDst, domain)
|
||||
|
||||
ue, isNew, err = DefaultUdpEndpointPool.GetOrCreate(realSrc, &UdpEndpointOptions{
|
||||
Handler: udpHandler,
|
||||
NatTimeout: natTimeout,
|
||||
Dialer: dialerForNew,
|
||||
Network: GetNetwork("udp", routingResult.Mark),
|
||||
Target: tgtToSend,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to GetOrCreate (policy: %v): %w", outbound.GetSelectionPolicy(), err)
|
||||
}
|
||||
// Get udp endpoint.
|
||||
var ue *UdpEndpoint
|
||||
retry := 0
|
||||
getNew:
|
||||
if retry > MaxRetry {
|
||||
return fmt.Errorf("touch max retry limit")
|
||||
}
|
||||
ue, isNew, err := DefaultUdpEndpointPool.GetOrCreate(realSrc, &UdpEndpointOptions{
|
||||
// Handler handles response packets and send it to the client.
|
||||
Handler: func(data []byte, from netip.AddrPort) (err error) {
|
||||
// Do not return conn-unrelated err in this func.
|
||||
return sendPkt(data, from, realSrc, src, lConn, lanWanFlag)
|
||||
},
|
||||
NatTimeout: natTimeout,
|
||||
Dialer: dialerForNew,
|
||||
Network: MagicNetwork("udp", routingResult.Mark),
|
||||
Target: dialTarget,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to GetOrCreate (policy: %v): %w", outbound.GetSelectionPolicy(), err)
|
||||
}
|
||||
|
||||
// If the udp endpoint has been not alive, remove it from pool and get a new one.
|
||||
if !isNew && outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) {
|
||||
// If the udp endpoint has been not alive, remove it from pool and get a new one.
|
||||
if !isNew && outbound.GetSelectionPolicy() != consts.DialerSelectionPolicy_Fixed && !ue.Dialer.MustGetAlive(networkType) {
|
||||
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"src": RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag),
|
||||
"network": networkType.String(),
|
||||
"dialer": ue.Dialer.Name(),
|
||||
"retry": retry,
|
||||
}).Debugln("Old udp endpoint was not alive and removed.")
|
||||
}
|
||||
_ = DefaultUdpEndpointPool.Remove(realSrc, ue)
|
||||
retry++
|
||||
goto getNew
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"src": RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag),
|
||||
"network": networkType.String(),
|
||||
"dialer": ue.Dialer.Name(),
|
||||
"retry": retry,
|
||||
}).Debugln("Old udp endpoint was not alive and removed.")
|
||||
}
|
||||
// This is real dialer.
|
||||
realDialer = ue.Dialer
|
||||
_ = DefaultUdpEndpointPool.Remove(realSrc, ue)
|
||||
retry++
|
||||
goto getNew
|
||||
}
|
||||
|
||||
_, err = ue.WriteTo(data, tgtToSend)
|
||||
if err != nil {
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"to": destToSend.String(),
|
||||
"domain": domain,
|
||||
"pid": routingResult.Pid,
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
"from": realSrc.String(),
|
||||
"network": networkType.String(),
|
||||
"err": err.Error(),
|
||||
"retry": retry,
|
||||
}).Debugln("Failed to write UDP packet request. Try to remove old UDP endpoint and retry.")
|
||||
}
|
||||
_ = DefaultUdpEndpointPool.Remove(realSrc, ue)
|
||||
retry++
|
||||
goto getNew
|
||||
}
|
||||
case consts.L4ProtoStr_TCP:
|
||||
// MUST be DNS.
|
||||
if !isDns {
|
||||
return fmt.Errorf("UDP to TCP only support DNS request")
|
||||
}
|
||||
isNew = true
|
||||
realDialer = dialerForNew
|
||||
|
||||
// We can block because we are in a coroutine.
|
||||
|
||||
conn, err := dialerForNew.Dial(GetNetwork("tcp", routingResult.Mark), tgtToSend)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(natTimeout))
|
||||
// We should write two byte length in the front of TCP DNS request.
|
||||
bReq := pool.Get(2 + len(data))
|
||||
defer pool.Put(bReq)
|
||||
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
|
||||
copy(bReq[2:], data)
|
||||
_, err = conn.Write(bReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write DNS req: %w", err)
|
||||
}
|
||||
|
||||
// Read two byte length.
|
||||
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
|
||||
}
|
||||
respLen := int(binary.BigEndian.Uint16(bReq))
|
||||
// Try to reuse the buf.
|
||||
var buf []byte
|
||||
if len(bReq) < respLen {
|
||||
buf = pool.Get(respLen)
|
||||
defer pool.Put(buf)
|
||||
} else {
|
||||
buf = bReq
|
||||
}
|
||||
var n int
|
||||
if n, err = io.ReadFull(conn, buf[:respLen]); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload: %w", err)
|
||||
}
|
||||
if err = udpHandler(buf[:n], destToSend); err != nil {
|
||||
return fmt.Errorf("failed to write DNS resp to client: %w", err)
|
||||
_, err = ue.WriteTo(data, dialTarget)
|
||||
if err != nil {
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"to": realDst.String(),
|
||||
"domain": domain,
|
||||
"pid": routingResult.Pid,
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
"from": realSrc.String(),
|
||||
"network": networkType.StringWithoutDns(),
|
||||
"err": err.Error(),
|
||||
"retry": retry,
|
||||
}).Debugln("Failed to write UDP packet request. Try to remove old UDP endpoint and retry.")
|
||||
}
|
||||
_ = DefaultUdpEndpointPool.Remove(realSrc, ue)
|
||||
retry++
|
||||
goto getNew
|
||||
}
|
||||
|
||||
// Print log.
|
||||
if isNew || isDns {
|
||||
// Only print routing for new connection to avoid the log exploded (Quic and BT).
|
||||
if isDns && c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
|
||||
q := dnsMessage.Questions[0]
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"network": string(l4proto) + string(ipversion) + "(DNS)",
|
||||
// Only print routing for new connection to avoid the log exploded (Quic and BT).
|
||||
if isNew {
|
||||
if c.log.IsLevelEnabled(logrus.InfoLevel) {
|
||||
fields := logrus.Fields{
|
||||
"network": networkType.StringWithoutDns(),
|
||||
"outbound": outbound.Name,
|
||||
"policy": outbound.GetSelectionPolicy(),
|
||||
"dialer": realDialer.Name(),
|
||||
"qname": strings.ToLower(q.Name.String()),
|
||||
"qtype": q.Type,
|
||||
"pid": routingResult.Pid,
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
}).Infof("%v <-> %v",
|
||||
RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), RefineAddrPortToShow(destToSend),
|
||||
)
|
||||
} else if c.log.IsLevelEnabled(logrus.InfoLevel) {
|
||||
if isDns && len(dnsMessage.Questions) > 0 {
|
||||
domain = strings.ToLower(dnsMessage.Questions[0].Name.String())
|
||||
}
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"network": string(l4proto) + string(ipversion),
|
||||
"outbound": outbound.Name,
|
||||
"policy": outbound.GetSelectionPolicy(),
|
||||
"dialer": realDialer.Name(),
|
||||
"dialer": ue.Dialer.Name(),
|
||||
"domain": domain,
|
||||
"pid": routingResult.Pid,
|
||||
"pname": ProcessName2String(routingResult.Pname[:]),
|
||||
"mac": Mac2String(routingResult.Mac[:]),
|
||||
}).Infof("%v <-> %v",
|
||||
RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), RefineAddrPortToShow(destToSend),
|
||||
)
|
||||
}
|
||||
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(realSrc, realDst.Addr(), lanWanFlag), RefineAddrPortToShow(realDst))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,32 @@ import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func (c *ControlPlane) Route(src, dst netip.AddrPort, domain string, l4proto consts.L4ProtoType, routingResult *bpfRoutingResult) (outboundIndex consts.OutboundIndex, mark uint32, err error) {
|
||||
var ipVersion consts.IpVersionType
|
||||
if dst.Addr().Is4() || dst.Addr().Is4In6() {
|
||||
ipVersion = consts.IpVersion_4
|
||||
} else {
|
||||
ipVersion = consts.IpVersion_6
|
||||
}
|
||||
bSrc := src.Addr().As16()
|
||||
bDst := dst.Addr().As16()
|
||||
if outboundIndex, mark, err = c.routingMatcher.Match(
|
||||
bSrc[:],
|
||||
bDst[:],
|
||||
src.Port(),
|
||||
dst.Port(),
|
||||
ipVersion,
|
||||
l4proto,
|
||||
domain,
|
||||
routingResult.Pname,
|
||||
append([]uint8{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, routingResult.Mac[:]...),
|
||||
); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return outboundIndex, mark, nil
|
||||
}
|
||||
|
||||
func (c *ControlPlaneCore) RetrieveRoutingResult(src, dst netip.AddrPort, l4proto uint8) (result *bpfRoutingResult, err error) {
|
||||
srcIp6 := src.Addr().As16()
|
||||
dstIp6 := dst.Addr().As16()
|
||||
@ -79,7 +105,7 @@ func CheckIpforward(ifname string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetNetwork(network string, mark uint32) string {
|
||||
func MagicNetwork(network string, mark uint32) string {
|
||||
if mark == 0 {
|
||||
return network
|
||||
} else {
|
||||
|
49
docs/dns.md
Normal file
49
docs/dns.md
Normal file
@ -0,0 +1,49 @@
|
||||
# DNS
|
||||
|
||||
## Examples:
|
||||
|
||||
```shell
|
||||
dns {
|
||||
upstream {
|
||||
# Value can be scheme://host:port.
|
||||
# Scheme list: tcp, udp, tcp+udp. Ongoing: https, tls, quic.
|
||||
# If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose
|
||||
# IPv4 or IPv6 to use according to group policy (such as min latency policy).
|
||||
# Please make sure DNS traffic will go through and be forwarded by dae, which is REQUIRED for domain routing.
|
||||
# If dial_mode is "ip", the upstream DNS answer SHOULD NOT be polluted, so domestic public DNS is not recommended.
|
||||
|
||||
alidns: 'udp://dns.alidns.com:53'
|
||||
googledns: 'tcp+udp://dns.google:53'
|
||||
}
|
||||
# The routing format of 'request' and 'response' is similar with section 'routing'.
|
||||
# See https://github.com/v2rayA/dae/blob/main/docs/routing.md
|
||||
request {
|
||||
# Built-in upstream in 'request': asis.
|
||||
# You can also use user-defined upstreams.
|
||||
|
||||
# Available functions: qname, qtype.
|
||||
|
||||
# DNS request name (omit suffix dot '.').
|
||||
qname(suffix: abc.com, keyword: google) -> googledns
|
||||
qname(full: ok.com, regex: '^yes') -> googledns
|
||||
# DNS request type
|
||||
qtype(a, aaaa) -> alidns
|
||||
qtype(cname) -> googledns
|
||||
|
||||
# If no match, fallback to this upstream.
|
||||
fallback: asis
|
||||
}
|
||||
response {
|
||||
# No built-in upstream in 'response'.
|
||||
# You can use user-defined upstreams.
|
||||
|
||||
# Available functions: qname, qtype, upstream, ip.
|
||||
# Accept the response if the request is sent to upstream 'googledns'. This is useful to avoid loop.
|
||||
upstream(googledns) -> accept
|
||||
# If DNS request name is not in CN and response answers include private IP, which is most likely polluted
|
||||
# in China mainland. Therefore, resend DNS request to 'googledns' to get correct result.
|
||||
!qname(geosite:cn) && ip(geoip:private) -> googledns
|
||||
fallback: accept
|
||||
}
|
||||
}
|
||||
```
|
@ -1,4 +1,4 @@
|
||||
# routing
|
||||
# Routing
|
||||
|
||||
## Examples:
|
||||
|
33
example.dae
33
example.dae
@ -21,15 +21,6 @@ global {
|
||||
# Group will switch node only when new_latency <= old_latency - tolerance.
|
||||
check_tolerance: 50ms
|
||||
|
||||
# Value can be scheme://host:port or empty string ''.
|
||||
# The scheme can be tcp/udp/tcp+udp. Empty string '' indicates as-is.
|
||||
# If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose
|
||||
# IPv4 or IPv6 to use according to group policy (such as min latency policy).
|
||||
# Please make sure DNS traffic will go through and be forwarded by dae, which is REQUIRED for domain routing.
|
||||
# The upstream DNS answer MUST NOT be polluted, so domestic public DNS is not recommended.
|
||||
# The request to DNS upstream follows the routing defined below.
|
||||
dns_upstream: 'udp://dns.alidns.com:53'
|
||||
|
||||
# The LAN interface to bind. Use it if you only want to proxy LAN instead of localhost.
|
||||
# Multiple interfaces split by ",".
|
||||
#lan_interface: docker0
|
||||
@ -79,6 +70,28 @@ node {
|
||||
'ss://LINK'
|
||||
}
|
||||
|
||||
# See more at https://github.com/v2rayA/dae/blob/main/docs/dns.md.
|
||||
dns {
|
||||
upstream {
|
||||
# Value can be scheme://host:port, where the scheme can be tcp/udp/tcp+udp.
|
||||
# If host is a domain and has both IPv4 and IPv6 record, dae will automatically choose
|
||||
# IPv4 or IPv6 to use according to group policy (such as min latency policy).
|
||||
# Please make sure DNS traffic will go through and be forwarded by dae, which is REQUIRED for domain routing.
|
||||
# If dial_mode is "ip", the upstream DNS answer SHOULD NOT be polluted, so domestic public DNS is not recommended.
|
||||
|
||||
alidns: 'udp://dns.alidns.com:53'
|
||||
googledns: 'tcp+udp://dns.google:53'
|
||||
}
|
||||
request {
|
||||
fallback: asis
|
||||
}
|
||||
response {
|
||||
upstream(googledns) -> accept
|
||||
!qname(geosite:cn) && ip(geoip:private) -> googledns
|
||||
fallback: accept
|
||||
}
|
||||
}
|
||||
|
||||
# Node group (outbound).
|
||||
group {
|
||||
my_group {
|
||||
@ -108,7 +121,7 @@ group {
|
||||
}
|
||||
}
|
||||
|
||||
# See routing.md for full examples.
|
||||
# See https://github.com/v2rayA/dae/blob/main/docs/routing.md for full examples.
|
||||
routing {
|
||||
### Preset rules.
|
||||
|
||||
|
3
go.mod
3
go.mod
@ -3,14 +3,13 @@ module github.com/v2rayA/dae
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/Asphaltt/lpmtrie v0.0.0-20220205153150-3d814250b8ab
|
||||
github.com/adrg/xdg v0.4.0
|
||||
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20221202181307-76fa05c21b12
|
||||
github.com/cilium/ebpf v0.10.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/json-iterator/go v1.1.12
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
|
||||
github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f
|
||||
github.com/mzz2017/softwind v0.0.0-20230224125402-d460ce1c5b4b
|
||||
github.com/safchain/ethtool v0.0.0-20230116090318-67cc41908669
|
||||
github.com/sirupsen/logrus v1.9.0
|
||||
github.com/spf13/cobra v1.6.1
|
||||
|
6
go.sum
6
go.sum
@ -1,5 +1,3 @@
|
||||
github.com/Asphaltt/lpmtrie v0.0.0-20220205153150-3d814250b8ab h1:hzN25CB5VzeKk3/c1fi1oT03N+5365nVOMPAxixkADY=
|
||||
github.com/Asphaltt/lpmtrie v0.0.0-20220205153150-3d814250b8ab/go.mod h1:TdNTLzn3VVXKfmHAULK5gY+h/A1gLQ8NnwLB6cSN54g=
|
||||
github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls=
|
||||
github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E=
|
||||
github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20221202181307-76fa05c21b12 h1:npHgfD4Tl2WJS3AJaMUi5ynGDPUBfkg3U3fCzDyXZ+4=
|
||||
@ -68,8 +66,8 @@ github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8=
|
||||
github.com/mzz2017/disk-bloom v1.0.1 h1:rEF9MiXd9qMW3ibRpqcerLXULoTgRlM21yqqJl1B90M=
|
||||
github.com/mzz2017/disk-bloom v1.0.1/go.mod h1:JLHETtUu44Z6iBmsqzkOtFlRvXSlKnxjwiBRDapizDI=
|
||||
github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f h1:Lmwy7FFI0PrWw0TgoQYtDiZBlCd/VZ1hBlySauTVWj4=
|
||||
github.com/mzz2017/softwind v0.0.0-20230220064728-6940dc11777f/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I=
|
||||
github.com/mzz2017/softwind v0.0.0-20230224125402-d460ce1c5b4b h1:Do2nwPU6oKlZGBNUeTvyiNjFHRuOqAlunrQ+jwvSCJM=
|
||||
github.com/mzz2017/softwind v0.0.0-20230224125402-d460ce1c5b4b/go.mod h1:V8GFOtdpTgzCJtCVXRqjmdDsY+PIhCCx4JpD0zq8Z7I=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
|
@ -15,11 +15,12 @@ import (
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
var smallBufferSize = 16
|
||||
var defaultBufferSize = 64
|
||||
|
||||
// A Buffer is a variable-sized buffer of bytes with Read and Write methods.
|
||||
// The zero value for Buffer is an empty buffer ready to use.
|
||||
type Buffer[T constraints.Unsigned] struct {
|
||||
buf []T // contents are the bytes buf[off : len(buf)]
|
||||
buf []T // contents are the bytes buf[off : len(buf)]
|
||||
}
|
||||
|
||||
// ErrTooLarge is passed to panic if memory cannot be allocated to store data in a buffer.
|
||||
@ -158,7 +159,7 @@ func makeSlice[T constraints.Unsigned](n int) []T {
|
||||
// sufficient to initialize a Buffer.
|
||||
func NewBuffer[T constraints.Unsigned](size int) *Buffer[T] {
|
||||
if size == 0 {
|
||||
size = 512
|
||||
size = defaultBufferSize
|
||||
}
|
||||
return &Buffer[T]{buf: make([]T, 0, size)}
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ type RoutingRule struct {
|
||||
Outbound Function
|
||||
}
|
||||
|
||||
func (r *RoutingRule) String(calcN bool) string {
|
||||
func (r *RoutingRule) String(replaceParamWithN bool) string {
|
||||
var builder strings.Builder
|
||||
var n int
|
||||
for i, f := range r.AndFunctions {
|
||||
@ -190,7 +190,7 @@ func (r *RoutingRule) String(calcN bool) string {
|
||||
}
|
||||
var paramBuilder strings.Builder
|
||||
n += len(f.Params)
|
||||
if calcN {
|
||||
if replaceParamWithN {
|
||||
paramBuilder.WriteString("[n = " + strconv.Itoa(n) + "]")
|
||||
} else {
|
||||
for j, param := range f.Params {
|
||||
|
@ -3,18 +3,26 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/v2rayA/dae/common"
|
||||
"encoding/binary"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Htons converts the unsigned short integer hostshort from host byte order to network byte order.
|
||||
func Htons(i uint16) uint16 {
|
||||
b := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(b, i)
|
||||
return *(*uint16)(unsafe.Pointer(&b[0]))
|
||||
}
|
||||
|
||||
func OpenRawSock(index int) (int, error) {
|
||||
sock, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, int(common.Htons(syscall.ETH_P_ALL)))
|
||||
sock, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, int(Htons(syscall.ETH_P_ALL)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sll := syscall.SockaddrLinklayer{
|
||||
Ifindex: index,
|
||||
Protocol: common.Htons(syscall.ETH_P_ALL),
|
||||
Protocol: Htons(syscall.ETH_P_ALL),
|
||||
}
|
||||
if err := syscall.Bind(sock, &sll); err != nil {
|
||||
return 0, err
|
||||
|
@ -1,5 +1,5 @@
|
||||
// Package trie is modified from https://github.com/openacid/succinct/blob/loc100/sskv.go.
|
||||
// Slower than about 50% but more memory saving.
|
||||
// Slower than about 30% but more than 40% memory saving.
|
||||
|
||||
package trie
|
||||
|
||||
@ -9,53 +9,30 @@ import (
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
var table = [256]byte{
|
||||
97: 0, // 'a'
|
||||
98: 1,
|
||||
99: 2,
|
||||
100: 3,
|
||||
101: 4,
|
||||
102: 5,
|
||||
103: 6,
|
||||
104: 7,
|
||||
105: 8,
|
||||
106: 9,
|
||||
107: 10,
|
||||
108: 11,
|
||||
109: 12,
|
||||
110: 13,
|
||||
111: 14,
|
||||
112: 15,
|
||||
113: 16,
|
||||
114: 17,
|
||||
115: 18,
|
||||
116: 19,
|
||||
117: 20,
|
||||
118: 21,
|
||||
119: 22,
|
||||
120: 23,
|
||||
121: 24,
|
||||
122: 25,
|
||||
'-': 26,
|
||||
'.': 27,
|
||||
'^': 28,
|
||||
'$': 29,
|
||||
'1': 30,
|
||||
'2': 31,
|
||||
'3': 32,
|
||||
'4': 33,
|
||||
'5': 34,
|
||||
'6': 35,
|
||||
'7': 36,
|
||||
'8': 37,
|
||||
'9': 38,
|
||||
'0': 39,
|
||||
type ValidChars struct {
|
||||
table [256]byte
|
||||
n uint16
|
||||
zeroChar byte
|
||||
}
|
||||
|
||||
const N = 40
|
||||
func NewValidChars(validChars []byte) (v *ValidChars) {
|
||||
v = new(ValidChars)
|
||||
for _, c := range validChars {
|
||||
if v.n == 0 {
|
||||
v.zeroChar = c
|
||||
}
|
||||
v.table[c] = byte(v.n)
|
||||
v.n++
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func IsValidChar(b byte) bool {
|
||||
return table[b] > 0 || b == 'a'
|
||||
func (v *ValidChars) Size() int {
|
||||
return int(v.n)
|
||||
}
|
||||
|
||||
func (v *ValidChars) IsValidChar(c byte) bool {
|
||||
return v.table[c] > 0 || c == v.zeroChar
|
||||
}
|
||||
|
||||
// Trie is a succinct, sorted and static string set impl with compacted trie as
|
||||
@ -103,22 +80,26 @@ type Trie struct {
|
||||
ranks, selects []int32
|
||||
labels *bitlist.CompactBitList
|
||||
ranksBL, selectsBL *bitlist.CompactBitList
|
||||
|
||||
chars *ValidChars
|
||||
}
|
||||
|
||||
// NewTrie creates a new *Trie struct, from a slice of sorted strings.
|
||||
func NewTrie(keys []string) (*Trie, error) {
|
||||
func NewTrie(keys []string, chars *ValidChars) (*Trie, error) {
|
||||
|
||||
// Check chars.
|
||||
for _, key := range keys {
|
||||
for _, c := range []byte(key) {
|
||||
if !IsValidChar(c) {
|
||||
if !chars.IsValidChar(c) {
|
||||
return nil, fmt.Errorf("char out of range: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ss := &Trie{}
|
||||
ss.labels = bitlist.NewCompactBitList(bits.Len8(N))
|
||||
ss := &Trie{
|
||||
chars: chars,
|
||||
labels: bitlist.NewCompactBitList(bits.Len(uint(chars.Size()))),
|
||||
}
|
||||
lIdx := 0
|
||||
|
||||
type qElt struct{ s, e, col int }
|
||||
@ -142,7 +123,7 @@ func NewTrie(keys []string) (*Trie, error) {
|
||||
}
|
||||
|
||||
queue = append(queue, qElt{frm, j, elt.col + 1})
|
||||
ss.labels.Append(uint64(table[keys[frm][elt.col]]))
|
||||
ss.labels.Append(uint64(chars.table[keys[frm][elt.col]]))
|
||||
setBit(&ss.labelBitmap, lIdx, 0)
|
||||
lIdx++
|
||||
}
|
||||
@ -190,13 +171,16 @@ func (ss *Trie) HasPrefix(word string) bool {
|
||||
return true
|
||||
}
|
||||
c := word[i]
|
||||
if !ss.chars.IsValidChar(c) {
|
||||
return false
|
||||
}
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||
// no more labels in this node
|
||||
return false
|
||||
}
|
||||
|
||||
if byte(ss.labels.Get(bmIdx-nodeId)) == table[c] {
|
||||
if byte(ss.labels.Get(bmIdx-nodeId)) == ss.chars.table[c] {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ func TestTrie(t *testing.T) {
|
||||
"zib.fmc^",
|
||||
"zk.ytamlacbci.",
|
||||
"zk.ytamlacbci^",
|
||||
})
|
||||
}, NewValidChars([]byte("0123456789abcdefghijklmnopqrstuvwxyz-.^")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -110,6 +110,9 @@ func TestTrie(t *testing.T) {
|
||||
if !(trie.HasPrefix("nc.^") == true) {
|
||||
t.Fatal("^.cn")
|
||||
}
|
||||
if !(trie.HasPrefix("nc._") == true) {
|
||||
t.Fatal("_.cn")
|
||||
}
|
||||
if !(trie.HasPrefix("n") == false) {
|
||||
t.Fatal("n")
|
||||
}
|
||||
|
Reference in New Issue
Block a user