refactor(dns): replace dnsmessage with miekg/dns (#188)

This commit is contained in:
mzz 2023-07-09 16:02:17 +08:00 committed by GitHub
parent b82b31e350
commit 00cf4bc3cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 327 additions and 427 deletions

View File

@ -12,14 +12,13 @@ import (
"io" "io"
"math" "math"
"net/netip" "net/netip"
"strings"
"sync" "sync"
"time" "time"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool" "github.com/mzz2017/softwind/pool"
"golang.org/x/net/dns/dnsmessage"
) )
var ( var (
@ -90,29 +89,37 @@ func SystemDns() (dns netip.AddrPort, err error) {
return systemDns, nil return systemDns, nil
} }
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (addrs []netip.Addr, err error) { func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (addrs []netip.Addr, err error) {
resources, err := resolve(ctx, d, dns, host, typ, network) resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, ans := range resources { for _, ans := range resources {
if ans.Header.Type != typ { if ans.Header().Rrtype != typ {
continue continue
} }
var (
ip netip.Addr
okk bool
)
switch typ { switch typ {
case dnsmessage.TypeA: case dnsmessage.TypeA:
a, ok := ans.Body.(*dnsmessage.AResource) a, ok := ans.(*dnsmessage.A)
if !ok { if !ok {
return nil, BadDnsAnsError return nil, BadDnsAnsError
} }
addrs = append(addrs, netip.AddrFrom4(a.A)) ip, okk = netip.AddrFromSlice(a.A)
case dnsmessage.TypeAAAA: case dnsmessage.TypeAAAA:
a, ok := ans.Body.(*dnsmessage.AAAAResource) a, ok := ans.(*dnsmessage.AAAA)
if !ok { if !ok {
return nil, BadDnsAnsError return nil, BadDnsAnsError
} }
addrs = append(addrs, netip.AddrFrom16(a.AAAA)) ip, okk = netip.AddrFromSlice(a.AAAA)
} }
if !okk {
continue
}
addrs = append(addrs, ip)
} }
return addrs, nil return addrs, nil
} }
@ -124,50 +131,47 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host
return nil, err return nil, err
} }
for _, ans := range resources { for _, ans := range resources {
if ans.Header.Type != typ { if ans.Header().Rrtype != typ {
continue continue
} }
ns, ok := ans.Body.(*dnsmessage.NSResource) ns, ok := ans.(*dnsmessage.NS)
if !ok { if !ok {
return nil, BadDnsAnsError return nil, BadDnsAnsError
} }
records = append(records, ns.NS.String()) records = append(records, ns.Ns)
} }
return records, nil return records, nil
} }
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (ans []dnsmessage.Resource, err error) { func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (ans []dnsmessage.RR, err error) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
fqdn := host fqdn := dnsmessage.CanonicalName(host)
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
switch typ { switch typ {
case dnsmessage.TypeA, dnsmessage.TypeAAAA: case dnsmessage.TypeA, dnsmessage.TypeAAAA:
if addr, err := netip.ParseAddr(host); err == nil { if addr, err := netip.ParseAddr(host); err == nil {
if (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA { if (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA {
return []dnsmessage.Resource{ return []dnsmessage.RR{
{ &dnsmessage.A{
Header: dnsmessage.ResourceHeader{ Hdr: dnsmessage.RR_Header{
Name: dnsmessage.MustNewName(fqdn), Name: dnsmessage.CanonicalName(fqdn),
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
TTL: 0, Ttl: 0,
Type: typ, Rrtype: typ,
}, },
Body: &dnsmessage.AResource{A: addr.As4()}, A: addr.AsSlice(),
}, },
}, nil }, nil
} else if addr.Is6() && typ == dnsmessage.TypeAAAA { } else if addr.Is6() && typ == dnsmessage.TypeAAAA {
return []dnsmessage.Resource{ return []dnsmessage.RR{
{ &dnsmessage.AAAA{
Header: dnsmessage.ResourceHeader{ Hdr: dnsmessage.RR_Header{
Name: dnsmessage.MustNewName(fqdn), Name: dnsmessage.CanonicalName(fqdn),
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
TTL: 0, Ttl: 0,
Type: typ, Rrtype: typ,
}, },
Body: &dnsmessage.AAAAResource{AAAA: addr.As16()}, AAAA: addr.AsSlice(),
}, },
}, nil }, nil
} }
@ -177,25 +181,18 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
default: default:
} }
// Build DNS req. // Build DNS req.
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{ builder := dnsmessage.Msg{
ID: uint16(fastrand.Intn(math.MaxUint16 + 1)), MsgHdr: dnsmessage.MsgHdr{
Response: false, Id: uint16(fastrand.Intn(math.MaxUint16 + 1)),
OpCode: 0, Response: false,
Truncated: false, Opcode: 0,
RecursionDesired: true, Truncated: false,
Authoritative: false, RecursionDesired: true,
}) Authoritative: false,
if err = builder.StartQuestions(); err != nil { },
return nil, err
} }
if err = builder.Question(dnsmessage.Question{ builder.SetQuestion(fqdn, typ)
Name: dnsmessage.MustNewName(fqdn), b, err := builder.Pack()
Type: typ,
Class: dnsmessage.ClassINET,
}); err != nil {
return nil, err
}
b, err := builder.Finish()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -265,12 +262,12 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
return return
} }
// Resolve DNS response and extract A/AAAA record. // Resolve DNS response and extract A/AAAA record.
var msg dnsmessage.Message var msg dnsmessage.Msg
if err = msg.Unpack(buf[:n]); err != nil { if err = msg.Unpack(buf[:n]); err != nil {
ch <- err ch <- err
return return
} }
ans = msg.Answers ans = msg.Answer
ch <- nil ch <- nil
}() }()
select { select {

View File

@ -12,9 +12,9 @@ import (
"net/netip" "net/netip"
"sync" "sync"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/netproxy"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
type Ip46 struct { type Ip46 struct {

View File

@ -0,0 +1,22 @@
package netutils
import (
"context"
"net/netip"
"testing"
"time"
"github.com/mzz2017/softwind/protocol/direct"
)
func TestResolveIp46(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ip46, err := ResolveIp46(ctx, direct.SymmetricDirect, netip.MustParseAddrPort("223.5.5.5:53"), "www.apple.com", "udp", false)
if err != nil {
t.Fatal(err)
}
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
t.Fatal("No record")
}
}

View File

@ -12,7 +12,6 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"github.com/mzz2017/softwind/netproxy"
"net/netip" "net/netip"
"net/url" "net/url"
"path/filepath" "path/filepath"
@ -22,9 +21,11 @@ import (
"time" "time"
"unsafe" "unsafe"
"github.com/mzz2017/softwind/netproxy"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal" internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -409,7 +410,7 @@ func NewGcm(key []byte) (cipher.AEAD, error) {
return cipher.NewGCM(block) return cipher.NewGCM(block)
} }
func AddrToDnsType(addr netip.Addr) dnsmessage.Type { func AddrToDnsType(addr netip.Addr) uint16 {
if addr.Is4() { if addr.Is4() {
return dnsmessage.TypeA return dnsmessage.TypeA
} else { } else {

View File

@ -16,8 +16,8 @@ import (
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/config"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
var BadUpstreamFormatError = fmt.Errorf("bad upstream format") var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
@ -148,7 +148,7 @@ func (s *Dns) InitUpstreams() {
wg.Wait() wg.Wait()
} }
func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) { func (s *Dns) RequestSelect(qname string, qtype uint16) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
// Route. // Route.
upstreamIndex, err = s.reqMatcher.Match(qname, qtype) upstreamIndex, err = s.reqMatcher.Match(qname, qtype)
if err != nil { if err != nil {
@ -170,29 +170,37 @@ func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex
return upstreamIndex, upstream, nil return upstreamIndex, upstream, nil
} }
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) { func (s *Dns) ResponseSelect(msg *dnsmessage.Msg, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
if !msg.Response { if !msg.Response {
return 0, nil, fmt.Errorf("DNS response expected but DNS request received") return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
} }
// Prepare routing. // Prepare routing.
var qname string var qname string
var qtype dnsmessage.Type var qtype uint16
var ips []netip.Addr var ips []netip.Addr
if len(msg.Questions) == 0 { if len(msg.Question) == 0 {
qname = "" qname = ""
qtype = 0 qtype = 0
} else { } else {
q := msg.Questions[0] q := msg.Question[0]
qname = q.Name.String() qname = q.Name
qtype = q.Type qtype = q.Qtype
for _, ans := range msg.Answers { for _, ans := range msg.Answer {
switch body := ans.Body.(type) { var (
case *dnsmessage.AResource: ip netip.Addr
ips = append(ips, netip.AddrFrom4(body.A)) ok bool
case *dnsmessage.AAAAResource: )
ips = append(ips, netip.AddrFrom16(body.AAAA)) switch body := ans.(type) {
case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
} }
if !ok {
continue
}
ips = append(ips, ip)
} }
} }

View File

@ -12,38 +12,20 @@ import (
"github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
var typeNames = map[string]dnsmessage.Type{ func TypeParserFactory(callback func(f *config_parser.Function, types []uint16, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
"A": dnsmessage.TypeA,
"NS": dnsmessage.TypeNS,
"CNAME": dnsmessage.TypeCNAME,
"SOA": dnsmessage.TypeSOA,
"PTR": dnsmessage.TypePTR,
"MX": dnsmessage.TypeMX,
"TXT": dnsmessage.TypeTXT,
"AAAA": dnsmessage.TypeAAAA,
"SRV": dnsmessage.TypeSRV,
"OPT": dnsmessage.TypeOPT,
"WKS": dnsmessage.TypeWKS,
"HINFO": dnsmessage.TypeHINFO,
"MINFO": dnsmessage.TypeMINFO,
"AXFR": dnsmessage.TypeAXFR,
"ALL": dnsmessage.TypeALL,
}
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) { return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
var types []dnsmessage.Type var types []uint16
for _, v := range paramValueGroup { for _, v := range paramValueGroup {
if t, ok := typeNames[strings.ToUpper(v)]; ok { if t, ok := dnsmessage.StringToType[strings.ToUpper(v)]; ok {
types = append(types, t) types = append(types, t)
continue continue
} }
if val, err := strconv.ParseUint(v, 0, 16); err == nil { if val, err := strconv.ParseUint(v, 0, 16); err == nil {
types = append(types, dnsmessage.Type(val)) types = append(types, uint16(val))
continue continue
} }
return fmt.Errorf("unknown DNS request type: %v", v) return fmt.Errorf("unknown DNS request type: %v", v)

View File

@ -15,7 +15,6 @@ import (
"github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
type RequestMatcherBuilder struct { type RequestMatcherBuilder struct {
@ -88,7 +87,7 @@ func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string,
return nil return nil
} }
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) { func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values { for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String() upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 { if i == len(values)-1 {
@ -166,7 +165,7 @@ type requestMatchSet struct {
func (m *RequestMatcher) Match( func (m *RequestMatcher) Match(
qName string, qName string,
qType dnsmessage.Type, qType uint16,
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) { ) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
var domainMatchBitmap []uint32 var domainMatchBitmap []uint32
if qName != "" { if qName != "" {
@ -185,7 +184,7 @@ func (m *RequestMatcher) Match(
goodSubrule = true goodSubrule = true
} }
case consts.MatchType_QType: case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) { if qType == match.Value {
goodSubrule = true goodSubrule = true
} }
case consts.MatchType_Fallback: case consts.MatchType_Fallback:

View File

@ -18,7 +18,6 @@ import (
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
"github.com/daeuniverse/dae/pkg/trie" "github.com/daeuniverse/dae/pkg/trie"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
type ResponseMatcherBuilder struct { type ResponseMatcherBuilder struct {
@ -138,7 +137,7 @@ func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values [
return nil return nil
} }
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) { func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values { for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String() upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 { if i == len(values)-1 {
@ -219,7 +218,7 @@ type responseMatchSet struct {
func (m *ResponseMatcher) Match( func (m *ResponseMatcher) Match(
qName string, qName string,
qType dnsmessage.Type, qType uint16,
ips []netip.Addr, ips []netip.Addr,
upstream consts.DnsRequestOutboundIndex, upstream consts.DnsRequestOutboundIndex,
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) { ) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
@ -253,7 +252,7 @@ func (m *ResponseMatcher) Match(
} }
} }
case consts.MatchType_QType: case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) { if qType == uint16(match.Value) {
goodSubrule = true goodSubrule = true
} }
case consts.MatchType_Upstream: case consts.MatchType_Upstream:

View File

@ -25,11 +25,11 @@ import (
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/common/netutils" "github.com/daeuniverse/dae/common/netutils"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/protocol/direct" "github.com/mzz2017/softwind/protocol/direct"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
type NetworkType struct { type NetworkType struct {

View File

@ -7,13 +7,14 @@ package socks
import ( import (
"context" "context"
"github.com/daeuniverse/dae/common/netutils"
"github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
"github.com/daeuniverse/dae/common/netutils"
"github.com/daeuniverse/dae/component/outbound/dialer"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
) )
func TestSocks5(t *testing.T) { func TestSocks5(t *testing.T) {

View File

@ -30,8 +30,8 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
log := logger.NewLogger("trace", false) log := logger.NewLogger("trace", false)
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second, CheckInterval: 15 * time.Second,
CheckTolerance: 0, CheckTolerance: 0,
CheckDnsTcp: false, CheckDnsTcp: false,
@ -46,7 +46,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
FixedIndex: fixedIndex, FixedIndex: fixedIndex,
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {}) }, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
d, _, err := g.Select(TestNetworkType) d, _, err := g.Select(TestNetworkType, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
fixedIndex = 0 fixedIndex = 0
g.selectionPolicy.FixedIndex = fixedIndex g.selectionPolicy.FixedIndex = fixedIndex
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
d, _, err := g.Select(TestNetworkType) d, _, err := g.Select(TestNetworkType, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -73,8 +73,8 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second, CheckInterval: 15 * time.Second,
} }
dialers := []*dialer.Dialer{ dialers := []*dialer.Dialer{
@ -120,7 +120,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
} }
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(d, alive) g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(d, alive)
} }
d, _, err := g.Select(TestNetworkType) d, _, err := g.Select(TestNetworkType, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -143,8 +143,8 @@ func TestDialerGroup_Select_Random(t *testing.T) {
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second, CheckInterval: 15 * time.Second,
} }
dialers := []*dialer.Dialer{ dialers := []*dialer.Dialer{
@ -159,7 +159,7 @@ func TestDialerGroup_Select_Random(t *testing.T) {
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {}) }, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
count := make([]int, len(dialers)) count := make([]int, len(dialers))
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
d, _, err := g.Select(TestNetworkType) d, _, err := g.Select(TestNetworkType, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -183,8 +183,8 @@ func TestDialerGroup_SetAlive(t *testing.T) {
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second, CheckInterval: 15 * time.Second,
} }
dialers := []*dialer.Dialer{ dialers := []*dialer.Dialer{
@ -201,7 +201,7 @@ func TestDialerGroup_SetAlive(t *testing.T) {
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(dialers[zeroTarget], false) g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(dialers[zeroTarget], false)
count := make([]int, len(dialers)) count := make([]int, len(dialers))
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
d, _, err := g.Select(TestNetworkType) d, _, err := g.Select(TestNetworkType, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -32,12 +32,12 @@ import (
"github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal" internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy" "github.com/mohae/deepcopy"
"github.com/mzz2017/softwind/pool" "github.com/mzz2017/softwind/pool"
"github.com/mzz2017/softwind/protocol/direct" "github.com/mzz2017/softwind/protocol/direct"
"github.com/mzz2017/softwind/transport/grpc" "github.com/mzz2017/softwind/transport/grpc"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -409,10 +409,10 @@ func NewControlPlane(
} }
return nil return nil
}, },
NewCache: func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) { NewCache: func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) {
return &DnsCache{ return &DnsCache{
DomainBitmap: plane.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn), DomainBitmap: plane.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn),
Answers: answers, Answer: answers,
Deadline: deadline, Deadline: deadline,
}, nil }, nil
}, },
@ -433,8 +433,13 @@ func NewControlPlane(
continue continue
} }
host := cacheKey[:lastDot] host := cacheKey[:lastDot]
typ := cacheKey[lastDot+1:] _typ := cacheKey[lastDot+1:]
_ = plane.dnsController.UpdateDnsCacheDeadline(host, typ, cache.Answers, cache.Deadline) typ, err := strconv.ParseUint(_typ, 10, 16)
if err != nil {
// Unexpected.
return nil, err
}
_ = plane.dnsController.UpdateDnsCacheDeadline(host, uint16(typ), cache.Answer, cache.Deadline)
} }
} else if _bpf != nil { } else if _bpf != nil {
// Is reloading, and dnsCache == nil. // Is reloading, and dnsCache == nil.
@ -509,43 +514,36 @@ func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err
/// Updates dns cache to support domain routing for hostname of dns_upstream. /// Updates dns cache to support domain routing for hostname of dns_upstream.
// Ten years later. // Ten years later.
deadline := time.Now().Add(time.Hour * 24 * 365 * 10) deadline := time.Now().Add(time.Hour * 24 * 365 * 10)
fqdn := dnsUpstream.Hostname fqdn := dnsmessage.CanonicalName(dnsUpstream.Hostname)
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
}
if dnsUpstream.Ip4.IsValid() { if dnsUpstream.Ip4.IsValid() {
typ := dnsmessage.TypeA typ := dnsmessage.TypeA
answers := []dnsmessage.Resource{{ answers := []dnsmessage.RR{&dnsmessage.A{
Header: dnsmessage.ResourceHeader{ Hdr: dnsmessage.RR_Header{
Name: dnsmessage.MustNewName(fqdn), Name: dnsmessage.CanonicalName(fqdn),
Type: typ, Rrtype: typ,
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero. Ttl: 0, // Must be zero.
},
Body: &dnsmessage.AResource{
A: dnsUpstream.Ip4.As4(),
}, },
A: dnsUpstream.Ip4.AsSlice(),
}} }}
if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil { if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
return err return err
} }
} }
if dnsUpstream.Ip6.IsValid() { if dnsUpstream.Ip6.IsValid() {
typ := dnsmessage.TypeAAAA typ := dnsmessage.TypeAAAA
answers := []dnsmessage.Resource{{ answers := []dnsmessage.RR{&dnsmessage.AAAA{
Header: dnsmessage.ResourceHeader{ Hdr: dnsmessage.RR_Header{
Name: dnsmessage.MustNewName(fqdn), Name: dnsmessage.CanonicalName(fqdn),
Type: typ, Rrtype: typ,
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero. Ttl: 0, // Must be zero.
},
Body: &dnsmessage.AAAAResource{
AAAA: dnsUpstream.Ip6.As16(),
}, },
AAAA: dnsUpstream.Ip6.AsSlice(),
}} }}
if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil { if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
return err return err
} }
} }

View File

@ -20,11 +20,11 @@ import (
"github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal" internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy" "github.com/mohae/deepcopy"
"github.com/safchain/ethtool" "github.com/safchain/ethtool"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
@ -629,15 +629,18 @@ func (c *controlPlaneCore) _bindWan(ifname string) error {
func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error { func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error {
// Parse ips from DNS resp answers. // Parse ips from DNS resp answers.
var ips []netip.Addr var ips []netip.Addr
for _, ans := range cache.Answers { for _, ans := range cache.Answer {
var ip netip.Addr var (
switch ans.Header.Type { ip netip.Addr
case dnsmessage.TypeA: ok bool
ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A) )
case dnsmessage.TypeAAAA: switch body := ans.(type) {
ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA) case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
} }
if ip.IsUnspecified() { if !ok || ip.IsUnspecified() {
continue continue
} }
ips = append(ips, ip) ips = append(ips, ip)
@ -672,15 +675,18 @@ func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error {
func (c *controlPlaneCore) BatchRemoveDomainRouting(cache *DnsCache) error { func (c *controlPlaneCore) BatchRemoveDomainRouting(cache *DnsCache) error {
// Parse ips from DNS resp answers. // Parse ips from DNS resp answers.
var ips []netip.Addr var ips []netip.Addr
for _, ans := range cache.Answers { for _, ans := range cache.Answer {
var ip netip.Addr var (
switch ans.Header.Type { ip netip.Addr
case dnsmessage.TypeA: ok bool
ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A) )
case dnsmessage.TypeAAAA: switch body := ans.(type) {
ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA) case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
} }
if ip.IsUnspecified() { if !ok || ip.IsUnspecified() {
continue continue
} }
ips = append(ips, ip) ips = append(ips, ip)

View File

@ -9,49 +9,39 @@ import (
"net/netip" "net/netip"
"time" "time"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy" "github.com/mohae/deepcopy"
"golang.org/x/net/dns/dnsmessage"
) )
type DnsCache struct { type DnsCache struct {
DomainBitmap []uint32 DomainBitmap []uint32
Answers []dnsmessage.Resource Answer []dnsmessage.RR
Deadline time.Time Deadline time.Time
} }
func (c *DnsCache) FillInto(req *dnsmessage.Message) { func (c *DnsCache) FillInto(req *dnsmessage.Msg) {
req.Answers = deepcopy.Copy(c.Answers).([]dnsmessage.Resource) req.Answer = deepcopy.Copy(c.Answer).([]dnsmessage.RR)
// No need to align because of no flipping now. req.Rcode = dnsmessage.RcodeSuccess
//// Align question and answer Name.
//if len(req.Questions) > 0 {
// q := req.Questions[0]
// for i := range req.Answers {
// if strings.EqualFold(req.Answers[i].Header.Name.String(), q.Name.String()) {
// req.Answers[i].Header.Name.Data = q.Name.Data
// }
// }
//}
req.RCode = dnsmessage.RCodeSuccess
req.Response = true req.Response = true
req.RecursionAvailable = true req.RecursionAvailable = true
req.Truncated = false req.Truncated = false
} }
func (c *DnsCache) IncludeIp(ip netip.Addr) bool { func (c *DnsCache) IncludeIp(ip netip.Addr) bool {
for _, ans := range c.Answers { for _, ans := range c.Answer {
switch body := ans.Body.(type) { switch body := ans.(type) {
case *dnsmessage.AResource: case *dnsmessage.A:
if !ip.Is4() { if !ip.Is4() {
continue continue
} }
if netip.AddrFrom4(body.A) == ip { if a, ok := netip.AddrFromSlice(body.A); ok && a == ip {
return true return true
} }
case *dnsmessage.AAAAResource: case *dnsmessage.AAAA:
if !ip.Is6() { if !ip.Is6() {
continue continue
} }
if netip.AddrFrom16(body.AAAA) == ip { if a, ok := netip.AddrFromSlice(body.AAAA); ok && a == ip {
return true return true
} }
} }
@ -60,9 +50,9 @@ func (c *DnsCache) IncludeIp(ip netip.Addr) bool {
} }
func (c *DnsCache) IncludeAnyIp() bool { func (c *DnsCache) IncludeAnyIp() bool {
for _, ans := range c.Answers { for _, ans := range c.Answer {
switch ans.Body.(type) { switch ans.(type) {
case *dnsmessage.AResource, *dnsmessage.AAAAResource: case *dnsmessage.A, *dnsmessage.AAAA:
return true return true
} }
} }

View File

@ -8,12 +8,12 @@ package control
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"math" "math"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -25,12 +25,12 @@ import (
"github.com/daeuniverse/dae/component/dns" "github.com/daeuniverse/dae/component/dns"
"github.com/daeuniverse/dae/component/outbound" "github.com/daeuniverse/dae/component/outbound"
"github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/outbound/dialer"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy" "github.com/mohae/deepcopy"
"github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool" "github.com/mzz2017/softwind/pool"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
const ( const (
@ -47,7 +47,6 @@ const (
) )
var ( var (
SuspectedRushAnswerError = fmt.Errorf("suspected DNS rush-answer")
UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type") UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type")
) )
@ -60,7 +59,7 @@ type DnsControllerOption struct {
Log *logrus.Logger Log *logrus.Logger
CacheAccessCallback func(cache *DnsCache) (err error) CacheAccessCallback func(cache *DnsCache) (err error)
CacheRemoveCallback func(cache *DnsCache) (err error) CacheRemoveCallback func(cache *DnsCache) (err error)
NewCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) NewCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error)
BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error) BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
IpVersionPrefer int IpVersionPrefer int
FixedDomainTtl map[string]int FixedDomainTtl map[string]int
@ -70,12 +69,12 @@ type DnsController struct {
handling sync.Map handling sync.Map
routing *dns.Dns routing *dns.Dns
qtypePrefer dnsmessage.Type qtypePrefer uint16
log *logrus.Logger log *logrus.Logger
cacheAccessCallback func(cache *DnsCache) (err error) cacheAccessCallback func(cache *DnsCache) (err error)
cacheRemoveCallback func(cache *DnsCache) (err error) cacheRemoveCallback func(cache *DnsCache) (err error)
newCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) newCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error)
bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error) bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
fixedDomainTtl map[string]int fixedDomainTtl map[string]int
@ -84,7 +83,7 @@ type DnsController struct {
dnsCache map[string]*DnsCache dnsCache map[string]*DnsCache
} }
func parseIpVersionPreference(prefer int) (dnsmessage.Type, error) { func parseIpVersionPreference(prefer int) (uint16, error) {
switch prefer := IpVersionPrefer(prefer); prefer { switch prefer := IpVersionPrefer(prefer); prefer {
case IpVersionPrefer_No: case IpVersionPrefer_No:
return 0, nil return 0, nil
@ -120,15 +119,12 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont
}, nil }, nil
} }
func (c *DnsController) cacheKey(qname string, qtype dnsmessage.Type) string { func (c *DnsController) cacheKey(qname string, qtype uint16) string {
// To fqdn. // To fqdn.
if !strings.HasSuffix(qname, ".") { return dnsmessage.CanonicalName(qname) + strconv.Itoa(int(qtype))
qname = qname + "."
}
return strings.ToLower(qname) + qtype.String()
} }
func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type) { func (c *DnsController) RemoveDnsRespCache(qname string, qtype uint16) {
c.dnsCacheMu.Lock() c.dnsCacheMu.Lock()
key := c.cacheKey(qname, qtype) key := c.cacheKey(qname, qtype)
_, ok := c.dnsCache[key] _, ok := c.dnsCache[key]
@ -137,7 +133,7 @@ func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type)
} }
c.dnsCacheMu.Unlock() c.dnsCacheMu.Unlock()
} }
func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type) (cache *DnsCache) { func (c *DnsController) LookupDnsRespCache(qname string, qtype uint16) (cache *DnsCache) {
c.dnsCacheMu.Lock() c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[c.cacheKey(qname, qtype)] cache, ok := c.dnsCache[c.cacheKey(qname, qtype)]
c.dnsCacheMu.Unlock() c.dnsCacheMu.Unlock()
@ -150,15 +146,15 @@ func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type)
} }
// LookupDnsRespCache_ will modify the msg in place. // LookupDnsRespCache_ will modify the msg in place.
func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) { func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Msg) (resp []byte) {
if len(msg.Questions) == 0 { if len(msg.Question) == 0 {
return nil return nil
} }
q := msg.Questions[0] q := msg.Question[0]
if msg.Response { if msg.Response {
return nil return nil
} }
cache := c.LookupDnsRespCache(q.Name.String(), q.Type) cache := c.LookupDnsRespCache(q.Name, q.Qtype)
if cache != nil { if cache != nil {
cache.FillInto(msg) cache.FillInto(msg)
b, err := msg.Pack() b, err := msg.Pack()
@ -176,28 +172,28 @@ func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byt
} }
// DnsRespHandler handle DNS resp. // DnsRespHandler handle DNS resp.
func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMsg *dnsmessage.Message, err error) { func (c *DnsController) DnsRespHandler(data []byte) (newMsg *dnsmessage.Msg, err error) {
var msg dnsmessage.Message var msg dnsmessage.Msg
if err = msg.Unpack(data); err != nil { if err = msg.Unpack(data); err != nil {
return nil, fmt.Errorf("unpack dns pkt: %w", err) return nil, fmt.Errorf("unpack dns pkt: %w", err)
} }
// Check healthy resp. // Check healthy resp.
if !msg.Response || len(msg.Questions) == 0 { if !msg.Response || len(msg.Question) == 0 {
return &msg, nil return &msg, nil
} }
q := msg.Questions[0] q := msg.Question[0]
// Check suc resp. // Check suc resp.
if msg.RCode != dnsmessage.RCodeSuccess { if msg.Rcode != dnsmessage.RcodeSuccess {
return &msg, nil return &msg, nil
} }
// Get TTL. // Get TTL.
var ttl uint32 var ttl uint32
for i := range msg.Answers { for i := range msg.Answer {
if ttl == 0 { if ttl == 0 {
ttl = msg.Answers[i].Header.TTL ttl = msg.Answer[i].Header().Ttl
break break
} }
} }
@ -207,7 +203,7 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs
} }
// Check req type. // Check req type.
switch q.Type { switch q.Qtype {
case dnsmessage.TypeA, dnsmessage.TypeAAAA: case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default: default:
// Update DnsCache. // Update DnsCache.
@ -218,17 +214,17 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs
} }
// Set ttl. // Set ttl.
for i := range msg.Answers { for i := range msg.Answer {
// Set TTL = zero. This requests applications must resend every request. // Set TTL = zero. This requests applications must resend every request.
// However, it may be not defined in the standard. // However, it may be not defined in the standard.
msg.Answers[i].Header.TTL = 0 msg.Answer[i].Header().Ttl = 0
} }
// Check if request A/AAAA record. // Check if request A/AAAA record.
var reqIpRecord bool var reqIpRecord bool
loop: loop:
for i := range msg.Questions { for i := range msg.Question {
switch msg.Questions[i].Type { switch msg.Question[i].Qtype {
case dnsmessage.TypeA, dnsmessage.TypeAAAA: case dnsmessage.TypeA, dnsmessage.TypeAAAA:
reqIpRecord = true reqIpRecord = true
break loop break loop
@ -242,23 +238,6 @@ loop:
return &msg, nil return &msg, nil
} }
if validateRushAns {
exist, e := EnsureAdditionalOpt(&msg, false)
if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) {
c.log.Warnf("EnsureAdditionalOpt: %v", e)
}
if e == nil && !exist {
// Additional record OPT in the request was ensured, and in normal case the resp should also set it.
// This DNS packet may be a rush-answer, and we should reject it.
c.log.WithFields(logrus.Fields{
"ques": q,
"addition": FormatDnsRsc(msg.Additionals),
"ans": FormatDnsRsc(msg.Answers),
}).Traceln("DNS rush-answer detected")
return nil, SuspectedRushAnswerError
}
}
// Update DnsCache. // Update DnsCache.
if err = c.updateDnsCache(&msg, ttl, &q); err != nil { if err = c.updateDnsCache(&msg, ttl, &q); err != nil {
return nil, err return nil, err
@ -267,31 +246,29 @@ loop:
return &msg, nil return &msg, nil
} }
func (c *DnsController) updateDnsCache(msg *dnsmessage.Message, ttl uint32, q *dnsmessage.Question) error { func (c *DnsController) updateDnsCache(msg *dnsmessage.Msg, ttl uint32, q *dnsmessage.Question) error {
// Update DnsCache. // Update DnsCache.
if c.log.IsLevelEnabled(logrus.TraceLevel) { if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"_qname": q.Name, "_qname": q.Name,
"rcode": msg.RCode, "rcode": msg.Rcode,
"ans": FormatDnsRsc(msg.Answers), "ans": FormatDnsRsc(msg.Answer),
"auth": FormatDnsRsc(msg.Authorities),
"addition": FormatDnsRsc(msg.Additionals),
}).Tracef("Update DNS record cache") }).Tracef("Update DNS record cache")
} }
if err := c.UpdateDnsCacheTtl(q.Name.String(), q.Type.String(), msg.Answers, int(ttl)); err != nil { if err := c.UpdateDnsCacheTtl(q.Name, q.Qtype, msg.Answer, int(ttl)); err != nil {
return err return err
} }
return nil return nil
} }
func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, answers []dnsmessage.Resource, deadlineFunc func(now time.Time, host string) time.Time) (err error) { func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadlineFunc func(now time.Time, host string) time.Time) (err error) {
var fqdn string var fqdn string
if strings.HasSuffix(host, ".") { if strings.HasSuffix(host, ".") {
fqdn = host fqdn = strings.ToLower(host)
host = host[:len(host)-1] host = host[:len(host)-1]
} else { } else {
fqdn = host + "." fqdn = dnsmessage.CanonicalName(host)
} }
// Bypass pure IP. // Bypass pure IP.
if _, err = netip.ParseAddr(host); err == nil { if _, err = netip.ParseAddr(host); err == nil {
@ -301,11 +278,11 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, ans
now := time.Now() now := time.Now()
deadline := deadlineFunc(now, host) deadline := deadlineFunc(now, host)
cacheKey := fqdn + dnsTyp cacheKey := c.cacheKey(fqdn, dnsTyp)
c.dnsCacheMu.Lock() c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[cacheKey] cache, ok := c.dnsCache[cacheKey]
if ok { if ok {
cache.Answers = answers cache.Answer = answers
cache.Deadline = deadline cache.Deadline = deadline
c.dnsCacheMu.Unlock() c.dnsCacheMu.Unlock()
} else { } else {
@ -324,7 +301,7 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, ans
return nil return nil
} }
func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp string, answers []dnsmessage.Resource, deadline time.Time) (err error) { func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadline time.Time) (err error) {
return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time { return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time {
if fixedTtl, ok := c.fixedDomainTtl[host]; ok { if fixedTtl, ok := c.fixedDomainTtl[host]; ok {
/// NOTICE: Cannot set TTL accurately. /// NOTICE: Cannot set TTL accurately.
@ -336,7 +313,7 @@ func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp string, answe
}) })
} }
func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp string, answers []dnsmessage.Resource, ttl int) (err error) { func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp uint16, answers []dnsmessage.RR, ttl int) (err error) {
return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time { return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time {
if fixedTtl, ok := c.fixedDomainTtl[host]; ok { if fixedTtl, ok := c.fixedDomainTtl[host]; ok {
return now.Add(time.Duration(fixedTtl) * time.Second) return now.Add(time.Duration(fixedTtl) * time.Second)
@ -346,27 +323,16 @@ func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp string, answers []
}) })
} }
func (c *DnsController) DnsRespHandlerFactory(validateRushAnsFunc func(from netip.AddrPort) bool) func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) { func (c *DnsController) DnsRespHandlerFactory() func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) {
return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) { return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) {
// Do not return conn-unrelated err in this func. // Do not return conn-unrelated err in this func.
validateRushAns := validateRushAnsFunc(from) msg, err = c.DnsRespHandler(data)
msg, err = c.DnsRespHandler(data, validateRushAns)
if err != nil { if err != nil {
if errors.Is(err, SuspectedRushAnswerError) { if c.log.IsLevelEnabled(logrus.DebugLevel) {
if validateRushAns { c.log.Debugf("DnsRespHandler: %v", err)
// Reject DNS rush-answer.
c.log.WithFields(logrus.Fields{
"from": from,
}).Tracef("DNS rush-answer rejected")
return nil, nil
}
} else {
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.Debugf("DnsRespHandler: %v", err)
}
return nil, err
} }
return nil, err
} }
return msg, nil return msg, nil
} }
@ -390,11 +356,11 @@ type dialArgument struct {
mark uint32 mark uint32
} }
func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) { func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (err error) {
if c.log.IsLevelEnabled(logrus.TraceLevel) && len(dnsMessage.Questions) > 0 { if c.log.IsLevelEnabled(logrus.TraceLevel) && len(dnsMessage.Question) > 0 {
q := dnsMessage.Questions[0] q := dnsMessage.Question[0]
c.log.Tracef("Received UDP(DNS) %v <-> %v: %v %v", c.log.Tracef("Received UDP(DNS) %v <-> %v: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), req.realDst.String(), strings.ToLower(q.Name.String()), q.Type, RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), req.realDst.String(), strings.ToLower(q.Name), QtypeToString(q.Qtype),
) )
} }
@ -404,10 +370,10 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
// Prepare qname, qtype. // Prepare qname, qtype.
var qname string var qname string
var qtype dnsmessage.Type var qtype uint16
if len(dnsMessage.Questions) != 0 { if len(dnsMessage.Question) != 0 {
qname = dnsMessage.Questions[0].Name.String() qname = dnsMessage.Question[0].Name
qtype = dnsMessage.Questions[0].Type qtype = dnsMessage.Question[0].Qtype
} }
// Check ip version preference and qtype. // Check ip version preference and qtype.
@ -421,9 +387,9 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
} }
// Try to make both A and AAAA lookups. // Try to make both A and AAAA lookups.
dnsMessage2 := deepcopy.Copy(dnsMessage).(*dnsmessage.Message) dnsMessage2 := deepcopy.Copy(dnsMessage).(*dnsmessage.Msg)
dnsMessage2.ID = uint16(fastrand.Intn(math.MaxUint16)) dnsMessage2.Id = uint16(fastrand.Intn(math.MaxUint16))
var qtype2 dnsmessage.Type var qtype2 uint16
switch qtype { switch qtype {
case dnsmessage.TypeA: case dnsmessage.TypeA:
qtype2 = dnsmessage.TypeAAAA qtype2 = dnsmessage.TypeAAAA
@ -432,7 +398,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
default: default:
return fmt.Errorf("unexpected qtype path") return fmt.Errorf("unexpected qtype path")
} }
dnsMessage2.Questions[0].Type = qtype2 dnsMessage2.Question[0].Qtype = qtype2
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
@ -452,7 +418,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
// resp is not valid. // resp is not valid.
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"qname": qname, "qname": qname,
}).Tracef("Reject %v due to resp not valid", qtype.String()) }).Tracef("Reject %v due to resp not valid", qtype)
return c.sendReject_(dnsMessage, req) return c.sendReject_(dnsMessage, req)
} }
// resp is valid. // resp is valid.
@ -465,25 +431,19 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
} }
func (c *DnsController) handle_( func (c *DnsController) handle_(
dnsMessage *dnsmessage.Message, dnsMessage *dnsmessage.Msg,
req *udpRequest, req *udpRequest,
needResp bool, needResp bool,
) (err error) { ) (err error) {
// Prepare qname, qtype. // Prepare qname, qtype.
var qname string var qname string
var qtype dnsmessage.Type var qtype uint16
if len(dnsMessage.Questions) != 0 { if len(dnsMessage.Question) != 0 {
q := dnsMessage.Questions[0] q := dnsMessage.Question[0]
qname = q.Name.String() qname = q.Name
qtype = q.Type qtype = q.Qtype
} }
//// NOTICE: Rush-answer detector was removed because it does not always work in all districts.
//// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
//// Because rush-answer has no resp OPT. We can distinguish them from multiple responses.
//// Note that additional record OPT may not be supported by home router either.
//_, _ = EnsureAdditionalOpt(dnsMessage, true)
// Route request. // Route request.
upstreamIndex, upstream, err := c.routing.RequestSelect(qname, qtype) upstreamIndex, upstream, err := c.routing.RequestSelect(qname, qtype)
if err != nil { if err != nil {
@ -509,10 +469,10 @@ func (c *DnsController) handle_(
return fmt.Errorf("failed to write cached DNS resp: %w", err) return fmt.Errorf("failed to write cached DNS resp: %w", err)
} }
} }
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 { if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Question) > 0 {
q := dnsMessage.Questions[0] q := dnsMessage.Question[0]
c.log.Debugf("UDP(DNS) %v <-> Cache: %v %v", c.log.Debugf("UDP(DNS) %v <-> Cache: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type, RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name), QtypeToString(q.Qtype),
) )
} }
return nil return nil
@ -524,7 +484,7 @@ func (c *DnsController) handle_(
upstreamName = upstream.String() upstreamName = upstream.String()
} }
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"question": dnsMessage.Questions, "question": dnsMessage.Question,
"upstream": upstreamName, "upstream": upstreamName,
}).Traceln("Request to DNS upstream") }).Traceln("Request to DNS upstream")
} }
@ -534,44 +494,44 @@ func (c *DnsController) handle_(
if err != nil { if err != nil {
return fmt.Errorf("pack DNS packet: %w", err) return fmt.Errorf("pack DNS packet: %w", err)
} }
return c.dialSend(0, req, data, dnsMessage.ID, upstream, needResp) return c.dialSend(0, req, data, dnsMessage.Id, upstream, needResp)
} }
// sendReject_ send empty answer. // sendReject_ send empty answer.
func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) { func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Msg, req *udpRequest) (err error) {
dnsMessage.Answers = nil dnsMessage.Answer = nil
if len(dnsMessage.Questions) > 0 { if len(dnsMessage.Question) > 0 {
q := dnsMessage.Questions[0] q := dnsMessage.Question[0]
switch typ := q.Type; typ { switch typ := q.Qtype; typ {
case dnsmessage.TypeA: case dnsmessage.TypeA:
dnsMessage.Answers = []dnsmessage.Resource{{ dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.A{
Header: dnsmessage.ResourceHeader{ Hdr: dnsmessage.RR_Header{
Name: q.Name, Name: q.Name,
Type: typ, Rrtype: typ,
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
TTL: 0, Ttl: 0,
}, },
Body: &dnsmessage.AResource{A: UnspecifiedAddressA.As4()}, A: UnspecifiedAddressA.AsSlice(),
}} }}
case dnsmessage.TypeAAAA: case dnsmessage.TypeAAAA:
dnsMessage.Answers = []dnsmessage.Resource{{ dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.AAAA{
Header: dnsmessage.ResourceHeader{ Hdr: dnsmessage.RR_Header{
Name: q.Name, Name: q.Name,
Type: typ, Rrtype: typ,
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
TTL: 0, Ttl: 0,
}, },
Body: &dnsmessage.AAAAResource{AAAA: UnspecifiedAddressAAAA.As16()}, AAAA: UnspecifiedAddressAAAA.AsSlice(),
}} }}
} }
} }
dnsMessage.RCode = dnsmessage.RCodeSuccess dnsMessage.Rcode = dnsmessage.RcodeSuccess
dnsMessage.Response = true dnsMessage.Response = true
dnsMessage.RecursionAvailable = true dnsMessage.RecursionAvailable = true
dnsMessage.Truncated = false dnsMessage.Truncated = false
if c.log.IsLevelEnabled(logrus.TraceLevel) { if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"question": dnsMessage.Questions, "question": dnsMessage.Question,
}).Traceln("Reject") }).Traceln("Reject")
} }
data, err := dnsMessage.Pack() data, err := dnsMessage.Pack()
@ -623,21 +583,9 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
} }
// dnsRespHandler caches dns response and check rush answers. // dnsRespHandler caches dns response and check rush answers.
dnsRespHandler := c.DnsRespHandlerFactory(func(from netip.AddrPort) bool { dnsRespHandler := c.DnsRespHandlerFactory()
//// NOTICE: Rush-answer detector was removed because it does not always work in all districts.
//// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
//// Because additional record OPT may not be supported by home router.
//// So se should trust home devices even if they make rush-answer (or looks like).
//return dialArgument.bestDialer.Property().Name == "direct" &&
// !from.Addr().IsPrivate() &&
// !from.Addr().IsLoopback() &&
// !from.Addr().IsUnspecified()
// Do not validate rush-answer.
return false
})
// Dial and send. // Dial and send.
var respMsg *dnsmessage.Message var respMsg *dnsmessage.Msg
// defer in a recursive call will delay Close(), thus we Close() before // defer in a recursive call will delay Close(), thus we Close() before
// the next recursive call. However, a connection cannot be closed twice. // the next recursive call. However, a connection cannot be closed twice.
// We should set a connClosed flag to avoid it. // We should set a connClosed flag to avoid it.
@ -774,23 +722,23 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
// Accept. // Accept.
if c.log.IsLevelEnabled(logrus.TraceLevel) { if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"question": respMsg.Questions, "question": respMsg.Question,
"upstream": upstreamName, "upstream": upstreamName,
}).Traceln("Accept") }).Traceln("Accept")
} }
case consts.DnsResponseOutboundIndex_Reject: case consts.DnsResponseOutboundIndex_Reject:
// Reject the request with empty answer. // Reject the request with empty answer.
respMsg.Answers = nil respMsg.Answer = nil
if c.log.IsLevelEnabled(logrus.TraceLevel) { if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"question": respMsg.Questions, "question": respMsg.Question,
"upstream": upstreamName, "upstream": upstreamName,
}).Traceln("Reject with empty answer") }).Traceln("Reject with empty answer")
} }
default: default:
if c.log.IsLevelEnabled(logrus.TraceLevel) { if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"question": respMsg.Questions, "question": respMsg.Question,
"last_upstream": upstreamName, "last_upstream": upstreamName,
"next_upstream": nextUpstream.String(), "next_upstream": nextUpstream.String(),
}).Traceln("Change DNS upstream and resend") }).Traceln("Change DNS upstream and resend")
@ -798,11 +746,14 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
return c.dialSend(invokingDepth+1, req, data, id, nextUpstream, needResp) return c.dialSend(invokingDepth+1, req, data, id, nextUpstream, needResp)
} }
if upstreamIndex.IsReserved() && c.log.IsLevelEnabled(logrus.InfoLevel) { if upstreamIndex.IsReserved() && c.log.IsLevelEnabled(logrus.InfoLevel) {
var qname, qtype string var (
if len(respMsg.Questions) > 0 { qname string
q := respMsg.Questions[0] qtype string
qname = strings.ToLower(q.Name.String()) )
qtype = q.Type.String() if len(respMsg.Question) > 0 {
q := respMsg.Question[0]
qname = strings.ToLower(q.Name)
qtype = QtypeToString(q.Qtype)
} }
fields := logrus.Fields{ fields := logrus.Fields{
"network": networkType.String(), "network": networkType.String(),
@ -825,7 +776,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
} }
} }
// Keep the id the same with request. // Keep the id the same with request.
respMsg.ID = id respMsg.Id = id
data, err = respMsg.Pack() data, err = respMsg.Pack()
if err != nil { if err != nil {
return err return err

View File

@ -6,98 +6,44 @@
package control package control
import ( import (
"encoding/binary"
"fmt" "fmt"
"hash/fnv" "strconv"
"math/rand"
"net/netip"
"strings" "strings"
"golang.org/x/net/dns/dnsmessage" dnsmessage "github.com/miekg/dns"
) )
// FlipDnsQuestionCase is used to reduce dns pollution.
func FlipDnsQuestionCase(dm *dnsmessage.Message) {
if len(dm.Questions) == 0 {
return
}
q := &dm.Questions[0]
// For reproducibility, we use dm.ID as input and add some entropy to make the results more discrete.
h := fnv.New64()
var buf [4]byte
binary.BigEndian.PutUint16(buf[:], dm.ID)
h.Write(buf[:2])
binary.BigEndian.PutUint32(buf[:], 20230204) // entropy
h.Write(buf[:])
r := rand.New(rand.NewSource(int64(h.Sum64())))
perm := r.Perm(int(q.Name.Length))
for i := 0; i < int(q.Name.Length/3); i++ {
j := perm[i]
// Upper to lower; lower to upper.
if q.Name.Data[j] >= 'a' && q.Name.Data[j] <= 'z' {
q.Name.Data[j] -= 'a' - 'A'
} else if q.Name.Data[j] >= 'A' && q.Name.Data[j] <= 'Z' {
q.Name.Data[j] += 'a' - 'A'
}
}
}
// EnsureAdditionalOpt makes sure there is additional record OPT in the request.
func EnsureAdditionalOpt(dm *dnsmessage.Message, isReqAdd bool) (bool, error) {
// Check healthy resp.
if isReqAdd == dm.Response || dm.RCode != dnsmessage.RCodeSuccess || len(dm.Questions) == 0 {
return false, UnsupportedQuestionTypeError
}
q := dm.Questions[0]
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return false, UnsupportedQuestionTypeError
}
for _, ad := range dm.Additionals {
if ad.Header.Type == dnsmessage.TypeOPT {
// Already has additional record OPT.
return true, nil
}
}
if !isReqAdd {
return false, nil
}
// Add one.
dm.Additionals = append(dm.Additionals, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName("."),
Type: dnsmessage.TypeOPT,
Class: 512, TTL: 0, Length: 0,
},
Body: &dnsmessage.OPTResource{
Options: nil,
},
})
return false, nil
}
type RscWrapper struct { type RscWrapper struct {
Rsc dnsmessage.Resource Rsc dnsmessage.RR
} }
func (w RscWrapper) String() string { func (w RscWrapper) String() string {
var strBody string var strBody string
switch body := w.Rsc.Body.(type) { switch body := w.Rsc.(type) {
case *dnsmessage.AResource: case *dnsmessage.A:
strBody = netip.AddrFrom4(body.A).String() strBody = body.A.String()
case *dnsmessage.AAAAResource: case *dnsmessage.AAAA:
strBody = netip.AddrFrom16(body.AAAA).String() strBody = body.AAAA.String()
case *dnsmessage.CNAME:
strBody = body.Target
default: default:
strBody = body.GoString() strBody = body.String()
} }
return fmt.Sprintf("%v(%v): %v", w.Rsc.Header.Name.String(), w.Rsc.Header.Type.String(), strBody) return fmt.Sprintf("%v(%v): %v", w.Rsc.Header().Name, QtypeToString(w.Rsc.Header().Rrtype), strBody)
} }
func FormatDnsRsc(ans []dnsmessage.Resource) string {
func FormatDnsRsc(ans []dnsmessage.RR) string {
var w []string var w []string
for _, a := range ans { for _, a := range ans {
w = append(w, RscWrapper{Rsc: a}.String()) w = append(w, RscWrapper{Rsc: a}.String())
} }
return strings.Join(w, "; ") return strings.Join(w, "; ")
} }
func QtypeToString(qtype uint16) string {
str, ok := dnsmessage.TypeToString[qtype]
if !ok {
str = strconv.Itoa(int(qtype))
}
return str
}

View File

@ -20,9 +20,9 @@ import (
"github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/daeuniverse/dae/component/sniffing" "github.com/daeuniverse/dae/component/sniffing"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal" internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer" "github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
) )
const ( const (
@ -31,11 +31,11 @@ const (
MaxRetry = 2 MaxRetry = 2
) )
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Message, timeout time.Duration) { func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) {
if sniffDns { if sniffDns {
var dnsmsg dnsmessage.Message var dnsmsg dnsmessage.Msg
if err := dnsmsg.Unpack(data); err == nil { if err := dnsmsg.Unpack(data); err == nil {
//log.Printf("DEBUG: lookup %v", dnsmsg.Questions[0].Name) //log.Printf("DEBUG: lookup %v", dnsmsg.Question[0].Name)
return &dnsmsg, DnsNatTimeout return &dnsmsg, DnsNatTimeout
} }
} }

View File

@ -6,7 +6,6 @@
package control package control
import ( import (
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"sync" "sync"
@ -47,9 +46,6 @@ func (ue *UdpEndpoint) start() {
ue.deadlineTimer.Reset(ue.NatTimeout) ue.deadlineTimer.Reset(ue.NatTimeout)
ue.mu.Unlock() ue.mu.Unlock()
if err = ue.handler(buf[:n], from); err != nil { if err = ue.handler(buf[:n], from); err != nil {
if errors.Is(err, SuspectedRushAnswerError) {
continue
}
break break
} }
} }

3
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/json-iterator/go v1.1.12 github.com/json-iterator/go v1.1.12
github.com/miekg/dns v1.1.55
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/mzz2017/softwind v0.0.0-20230708102709-26ff44839573 github.com/mzz2017/softwind v0.0.0-20230708102709-26ff44839573
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
@ -21,7 +22,6 @@ require (
github.com/x-cray/logrus-prefixed-formatter v0.5.2 github.com/x-cray/logrus-prefixed-formatter v0.5.2
golang.org/x/crypto v0.11.0 golang.org/x/crypto v0.11.0
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
golang.org/x/net v0.12.0
golang.org/x/sys v0.10.0 golang.org/x/sys v0.10.0
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
) )
@ -38,6 +38,7 @@ require (
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
golang.org/x/mod v0.12.0 // indirect golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.12.0 // indirect
golang.org/x/tools v0.11.0 // indirect golang.org/x/tools v0.11.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130 // indirect
) )

3
go.sum
View File

@ -78,6 +78,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -175,6 +177,7 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=