From 00cf4bc3cd84b76e9a1c1f75655aec3e6ccde13c Mon Sep 17 00:00:00 2001 From: mzz <2017@duck.com> Date: Sun, 9 Jul 2023 16:02:17 +0800 Subject: [PATCH] refactor(dns): replace dnsmessage with miekg/dns (#188) --- common/netutils/dns.go | 101 ++++--- common/netutils/ip46.go | 2 +- common/netutils/ip46_test.go | 22 ++ common/utils.go | 7 +- component/dns/dns.go | 36 ++- component/dns/function_parser.go | 28 +- component/dns/request_routing.go | 7 +- component/dns/response_routing.go | 7 +- .../outbound/dialer/connectivity_check.go | 2 +- component/outbound/dialer/socks/socks_test.go | 9 +- component/outbound/dialer_group_test.go | 26 +- control/control_plane.go | 56 ++-- control/control_plane_core.go | 40 +-- control/dns_cache.go | 38 +-- control/dns_control.go | 257 +++++++----------- control/dns_utils.go | 98 ++----- control/udp.go | 8 +- control/udp_endpoint.go | 4 - go.mod | 3 +- go.sum | 3 + 20 files changed, 327 insertions(+), 427 deletions(-) create mode 100644 common/netutils/ip46_test.go diff --git a/common/netutils/dns.go b/common/netutils/dns.go index b38e450..9b93838 100644 --- a/common/netutils/dns.go +++ b/common/netutils/dns.go @@ -12,14 +12,13 @@ import ( "io" "math" "net/netip" - "strings" "sync" "time" + dnsmessage "github.com/miekg/dns" "github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/pool" - "golang.org/x/net/dns/dnsmessage" ) var ( @@ -90,29 +89,37 @@ func SystemDns() (dns netip.AddrPort, err error) { return systemDns, nil } -func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (addrs []netip.Addr, err error) { +func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (addrs []netip.Addr, err error) { resources, err := resolve(ctx, d, dns, host, typ, network) if err != nil { return nil, err } for _, ans := range resources { - if ans.Header.Type != typ { + if ans.Header().Rrtype != typ { continue } + var ( + ip netip.Addr + okk bool + ) switch typ { case dnsmessage.TypeA: - a, ok := ans.Body.(*dnsmessage.AResource) + a, ok := ans.(*dnsmessage.A) if !ok { return nil, BadDnsAnsError } - addrs = append(addrs, netip.AddrFrom4(a.A)) + ip, okk = netip.AddrFromSlice(a.A) case dnsmessage.TypeAAAA: - a, ok := ans.Body.(*dnsmessage.AAAAResource) + a, ok := ans.(*dnsmessage.AAAA) if !ok { return nil, BadDnsAnsError } - addrs = append(addrs, netip.AddrFrom16(a.AAAA)) + ip, okk = netip.AddrFromSlice(a.AAAA) } + if !okk { + continue + } + addrs = append(addrs, ip) } return addrs, nil } @@ -124,50 +131,47 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host return nil, err } for _, ans := range resources { - if ans.Header.Type != typ { + if ans.Header().Rrtype != typ { continue } - ns, ok := ans.Body.(*dnsmessage.NSResource) + ns, ok := ans.(*dnsmessage.NS) if !ok { return nil, BadDnsAnsError } - records = append(records, ns.NS.String()) + records = append(records, ns.Ns) } return records, nil } -func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (ans []dnsmessage.Resource, err error) { +func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (ans []dnsmessage.RR, err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() - fqdn := host - if !strings.HasSuffix(fqdn, ".") { - fqdn += "." - } + fqdn := dnsmessage.CanonicalName(host) switch typ { case dnsmessage.TypeA, dnsmessage.TypeAAAA: if addr, err := netip.ParseAddr(host); err == nil { if (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA { - return []dnsmessage.Resource{ - { - Header: dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(fqdn), - Class: dnsmessage.ClassINET, - TTL: 0, - Type: typ, + return []dnsmessage.RR{ + &dnsmessage.A{ + Hdr: dnsmessage.RR_Header{ + Name: dnsmessage.CanonicalName(fqdn), + Class: dnsmessage.ClassINET, + Ttl: 0, + Rrtype: typ, }, - Body: &dnsmessage.AResource{A: addr.As4()}, + A: addr.AsSlice(), }, }, nil } else if addr.Is6() && typ == dnsmessage.TypeAAAA { - return []dnsmessage.Resource{ - { - Header: dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(fqdn), - Class: dnsmessage.ClassINET, - TTL: 0, - Type: typ, + return []dnsmessage.RR{ + &dnsmessage.AAAA{ + Hdr: dnsmessage.RR_Header{ + Name: dnsmessage.CanonicalName(fqdn), + Class: dnsmessage.ClassINET, + Ttl: 0, + Rrtype: typ, }, - Body: &dnsmessage.AAAAResource{AAAA: addr.As16()}, + AAAA: addr.AsSlice(), }, }, nil } @@ -177,25 +181,18 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st default: } // Build DNS req. - builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{ - ID: uint16(fastrand.Intn(math.MaxUint16 + 1)), - Response: false, - OpCode: 0, - Truncated: false, - RecursionDesired: true, - Authoritative: false, - }) - if err = builder.StartQuestions(); err != nil { - return nil, err + builder := dnsmessage.Msg{ + MsgHdr: dnsmessage.MsgHdr{ + Id: uint16(fastrand.Intn(math.MaxUint16 + 1)), + Response: false, + Opcode: 0, + Truncated: false, + RecursionDesired: true, + Authoritative: false, + }, } - if err = builder.Question(dnsmessage.Question{ - Name: dnsmessage.MustNewName(fqdn), - Type: typ, - Class: dnsmessage.ClassINET, - }); err != nil { - return nil, err - } - b, err := builder.Finish() + builder.SetQuestion(fqdn, typ) + b, err := builder.Pack() if err != nil { return nil, err } @@ -265,12 +262,12 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st return } // Resolve DNS response and extract A/AAAA record. - var msg dnsmessage.Message + var msg dnsmessage.Msg if err = msg.Unpack(buf[:n]); err != nil { ch <- err return } - ans = msg.Answers + ans = msg.Answer ch <- nil }() select { diff --git a/common/netutils/ip46.go b/common/netutils/ip46.go index 6d2f007..3d8ac90 100644 --- a/common/netutils/ip46.go +++ b/common/netutils/ip46.go @@ -12,9 +12,9 @@ import ( "net/netip" "sync" + dnsmessage "github.com/miekg/dns" "github.com/mzz2017/softwind/netproxy" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) type Ip46 struct { diff --git a/common/netutils/ip46_test.go b/common/netutils/ip46_test.go new file mode 100644 index 0000000..3f0b9d5 --- /dev/null +++ b/common/netutils/ip46_test.go @@ -0,0 +1,22 @@ +package netutils + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/mzz2017/softwind/protocol/direct" +) + +func TestResolveIp46(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ip46, err := ResolveIp46(ctx, direct.SymmetricDirect, netip.MustParseAddrPort("223.5.5.5:53"), "www.apple.com", "udp", false) + if err != nil { + t.Fatal(err) + } + if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() { + t.Fatal("No record") + } +} diff --git a/common/utils.go b/common/utils.go index 2067903..80a0505 100644 --- a/common/utils.go +++ b/common/utils.go @@ -12,7 +12,6 @@ import ( "encoding/binary" "encoding/hex" "fmt" - "github.com/mzz2017/softwind/netproxy" "net/netip" "net/url" "path/filepath" @@ -22,9 +21,11 @@ import ( "time" "unsafe" + "github.com/mzz2017/softwind/netproxy" + internal "github.com/daeuniverse/dae/pkg/ebpf_internal" + dnsmessage "github.com/miekg/dns" "github.com/vishvananda/netlink" - "golang.org/x/net/dns/dnsmessage" "golang.org/x/sys/unix" ) @@ -409,7 +410,7 @@ func NewGcm(key []byte) (cipher.AEAD, error) { return cipher.NewGCM(block) } -func AddrToDnsType(addr netip.Addr) dnsmessage.Type { +func AddrToDnsType(addr netip.Addr) uint16 { if addr.Is4() { return dnsmessage.TypeA } else { diff --git a/component/dns/dns.go b/component/dns/dns.go index f997dd0..f1ed9d5 100644 --- a/component/dns/dns.go +++ b/component/dns/dns.go @@ -16,8 +16,8 @@ import ( "github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/config" + dnsmessage "github.com/miekg/dns" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) var BadUpstreamFormatError = fmt.Errorf("bad upstream format") @@ -148,7 +148,7 @@ func (s *Dns) InitUpstreams() { wg.Wait() } -func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) { +func (s *Dns) RequestSelect(qname string, qtype uint16) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) { // Route. upstreamIndex, err = s.reqMatcher.Match(qname, qtype) if err != nil { @@ -170,29 +170,37 @@ func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex return upstreamIndex, upstream, nil } -func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) { +func (s *Dns) ResponseSelect(msg *dnsmessage.Msg, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) { if !msg.Response { return 0, nil, fmt.Errorf("DNS response expected but DNS request received") } // Prepare routing. var qname string - var qtype dnsmessage.Type + var qtype uint16 var ips []netip.Addr - if len(msg.Questions) == 0 { + if len(msg.Question) == 0 { qname = "" qtype = 0 } else { - q := msg.Questions[0] - qname = q.Name.String() - qtype = q.Type - for _, ans := range msg.Answers { - switch body := ans.Body.(type) { - case *dnsmessage.AResource: - ips = append(ips, netip.AddrFrom4(body.A)) - case *dnsmessage.AAAAResource: - ips = append(ips, netip.AddrFrom16(body.AAAA)) + q := msg.Question[0] + qname = q.Name + qtype = q.Qtype + for _, ans := range msg.Answer { + var ( + ip netip.Addr + ok bool + ) + switch body := ans.(type) { + case *dnsmessage.A: + ip, ok = netip.AddrFromSlice(body.A) + case *dnsmessage.AAAA: + ip, ok = netip.AddrFromSlice(body.AAAA) } + if !ok { + continue + } + ips = append(ips, ip) } } diff --git a/component/dns/function_parser.go b/component/dns/function_parser.go index f4fee4c..a025c9c 100644 --- a/component/dns/function_parser.go +++ b/component/dns/function_parser.go @@ -12,38 +12,20 @@ import ( "github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/pkg/config_parser" + dnsmessage "github.com/miekg/dns" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) -var typeNames = map[string]dnsmessage.Type{ - "A": dnsmessage.TypeA, - "NS": dnsmessage.TypeNS, - "CNAME": dnsmessage.TypeCNAME, - "SOA": dnsmessage.TypeSOA, - "PTR": dnsmessage.TypePTR, - "MX": dnsmessage.TypeMX, - "TXT": dnsmessage.TypeTXT, - "AAAA": dnsmessage.TypeAAAA, - "SRV": dnsmessage.TypeSRV, - "OPT": dnsmessage.TypeOPT, - "WKS": dnsmessage.TypeWKS, - "HINFO": dnsmessage.TypeHINFO, - "MINFO": dnsmessage.TypeMINFO, - "AXFR": dnsmessage.TypeAXFR, - "ALL": dnsmessage.TypeALL, -} - -func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser { +func TypeParserFactory(callback func(f *config_parser.Function, types []uint16, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser { return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) { - var types []dnsmessage.Type + var types []uint16 for _, v := range paramValueGroup { - if t, ok := typeNames[strings.ToUpper(v)]; ok { + if t, ok := dnsmessage.StringToType[strings.ToUpper(v)]; ok { types = append(types, t) continue } if val, err := strconv.ParseUint(v, 0, 16); err == nil { - types = append(types, dnsmessage.Type(val)) + types = append(types, uint16(val)) continue } return fmt.Errorf("unknown DNS request type: %v", v) diff --git a/component/dns/request_routing.go b/component/dns/request_routing.go index 69e439e..b2334d9 100644 --- a/component/dns/request_routing.go +++ b/component/dns/request_routing.go @@ -15,7 +15,6 @@ import ( "github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/pkg/config_parser" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) type RequestMatcherBuilder struct { @@ -88,7 +87,7 @@ func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string, return nil } -func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) { +func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) { for i, value := range values { upstreamName := consts.OutboundLogicalOr.String() if i == len(values)-1 { @@ -166,7 +165,7 @@ type requestMatchSet struct { func (m *RequestMatcher) Match( qName string, - qType dnsmessage.Type, + qType uint16, ) (upstreamIndex consts.DnsRequestOutboundIndex, err error) { var domainMatchBitmap []uint32 if qName != "" { @@ -185,7 +184,7 @@ func (m *RequestMatcher) Match( goodSubrule = true } case consts.MatchType_QType: - if qType == dnsmessage.Type(match.Value) { + if qType == match.Value { goodSubrule = true } case consts.MatchType_Fallback: diff --git a/component/dns/response_routing.go b/component/dns/response_routing.go index acee549..83a6f22 100644 --- a/component/dns/response_routing.go +++ b/component/dns/response_routing.go @@ -18,7 +18,6 @@ import ( "github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/trie" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) type ResponseMatcherBuilder struct { @@ -138,7 +137,7 @@ func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values [ return nil } -func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) { +func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) { for i, value := range values { upstreamName := consts.OutboundLogicalOr.String() if i == len(values)-1 { @@ -219,7 +218,7 @@ type responseMatchSet struct { func (m *ResponseMatcher) Match( qName string, - qType dnsmessage.Type, + qType uint16, ips []netip.Addr, upstream consts.DnsRequestOutboundIndex, ) (upstreamIndex consts.DnsResponseOutboundIndex, err error) { @@ -253,7 +252,7 @@ func (m *ResponseMatcher) Match( } } case consts.MatchType_QType: - if qType == dnsmessage.Type(match.Value) { + if qType == uint16(match.Value) { goodSubrule = true } case consts.MatchType_Upstream: diff --git a/component/outbound/dialer/connectivity_check.go b/component/outbound/dialer/connectivity_check.go index 0929590..601cf97 100644 --- a/component/outbound/dialer/connectivity_check.go +++ b/component/outbound/dialer/connectivity_check.go @@ -25,11 +25,11 @@ import ( "github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/netutils" + dnsmessage "github.com/miekg/dns" "github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/protocol/direct" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) type NetworkType struct { diff --git a/component/outbound/dialer/socks/socks_test.go b/component/outbound/dialer/socks/socks_test.go index a663eca..410de06 100644 --- a/component/outbound/dialer/socks/socks_test.go +++ b/component/outbound/dialer/socks/socks_test.go @@ -7,13 +7,14 @@ package socks import ( "context" - "github.com/daeuniverse/dae/common/netutils" - "github.com/daeuniverse/dae/component/outbound/dialer" - "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" "net/netip" "testing" "time" + + "github.com/daeuniverse/dae/common/netutils" + "github.com/daeuniverse/dae/component/outbound/dialer" + dnsmessage "github.com/miekg/dns" + "github.com/sirupsen/logrus" ) func TestSocks5(t *testing.T) { diff --git a/component/outbound/dialer_group_test.go b/component/outbound/dialer_group_test.go index 5ad5dfd..ca200d8 100644 --- a/component/outbound/dialer_group_test.go +++ b/component/outbound/dialer_group_test.go @@ -30,8 +30,8 @@ func TestDialerGroup_Select_Fixed(t *testing.T) { log := logger.NewLogger("trace", false) option := &dialer.GlobalOption{ Log: log, - TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, - CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, + TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}}, + CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}}, CheckInterval: 15 * time.Second, CheckTolerance: 0, CheckDnsTcp: false, @@ -46,7 +46,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) { FixedIndex: fixedIndex, }, func(alive bool, networkType *dialer.NetworkType, isInit bool) {}) for i := 0; i < 10; i++ { - d, _, err := g.Select(TestNetworkType) + d, _, err := g.Select(TestNetworkType, false) if err != nil { t.Fatal(err) } @@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) { fixedIndex = 0 g.selectionPolicy.FixedIndex = fixedIndex for i := 0; i < 10; i++ { - d, _, err := g.Select(TestNetworkType) + d, _, err := g.Select(TestNetworkType, false) if err != nil { t.Fatal(err) } @@ -73,8 +73,8 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) { option := &dialer.GlobalOption{ Log: log, - TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, - CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, + TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}}, + CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}}, CheckInterval: 15 * time.Second, } dialers := []*dialer.Dialer{ @@ -120,7 +120,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) { } g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(d, alive) } - d, _, err := g.Select(TestNetworkType) + d, _, err := g.Select(TestNetworkType, false) if err != nil { t.Fatal(err) } @@ -143,8 +143,8 @@ func TestDialerGroup_Select_Random(t *testing.T) { option := &dialer.GlobalOption{ Log: log, - TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, - CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, + TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}}, + CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}}, CheckInterval: 15 * time.Second, } dialers := []*dialer.Dialer{ @@ -159,7 +159,7 @@ func TestDialerGroup_Select_Random(t *testing.T) { }, func(alive bool, networkType *dialer.NetworkType, isInit bool) {}) count := make([]int, len(dialers)) for i := 0; i < 100; i++ { - d, _, err := g.Select(TestNetworkType) + d, _, err := g.Select(TestNetworkType, false) if err != nil { t.Fatal(err) } @@ -183,8 +183,8 @@ func TestDialerGroup_SetAlive(t *testing.T) { option := &dialer.GlobalOption{ Log: log, - TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl}, - CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns}, + TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}}, + CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}}, CheckInterval: 15 * time.Second, } dialers := []*dialer.Dialer{ @@ -201,7 +201,7 @@ func TestDialerGroup_SetAlive(t *testing.T) { g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(dialers[zeroTarget], false) count := make([]int, len(dialers)) for i := 0; i < 100; i++ { - d, _, err := g.Select(TestNetworkType) + d, _, err := g.Select(TestNetworkType, false) if err != nil { t.Fatal(err) } diff --git a/control/control_plane.go b/control/control_plane.go index 1a51c1e..8c08578 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -32,12 +32,12 @@ import ( "github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/pkg/config_parser" internal "github.com/daeuniverse/dae/pkg/ebpf_internal" + dnsmessage "github.com/miekg/dns" "github.com/mohae/deepcopy" "github.com/mzz2017/softwind/pool" "github.com/mzz2017/softwind/protocol/direct" "github.com/mzz2017/softwind/transport/grpc" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" "golang.org/x/sys/unix" ) @@ -409,10 +409,10 @@ func NewControlPlane( } return nil }, - NewCache: func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) { + NewCache: func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) { return &DnsCache{ DomainBitmap: plane.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn), - Answers: answers, + Answer: answers, Deadline: deadline, }, nil }, @@ -433,8 +433,13 @@ func NewControlPlane( continue } host := cacheKey[:lastDot] - typ := cacheKey[lastDot+1:] - _ = plane.dnsController.UpdateDnsCacheDeadline(host, typ, cache.Answers, cache.Deadline) + _typ := cacheKey[lastDot+1:] + typ, err := strconv.ParseUint(_typ, 10, 16) + if err != nil { + // Unexpected. + return nil, err + } + _ = plane.dnsController.UpdateDnsCacheDeadline(host, uint16(typ), cache.Answer, cache.Deadline) } } else if _bpf != nil { // Is reloading, and dnsCache == nil. @@ -509,43 +514,36 @@ func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err /// Updates dns cache to support domain routing for hostname of dns_upstream. // Ten years later. deadline := time.Now().Add(time.Hour * 24 * 365 * 10) - fqdn := dnsUpstream.Hostname - if !strings.HasSuffix(fqdn, ".") { - fqdn = fqdn + "." - } + fqdn := dnsmessage.CanonicalName(dnsUpstream.Hostname) if dnsUpstream.Ip4.IsValid() { typ := dnsmessage.TypeA - answers := []dnsmessage.Resource{{ - Header: dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(fqdn), - Type: typ, - Class: dnsmessage.ClassINET, - TTL: 0, // Must be zero. - }, - Body: &dnsmessage.AResource{ - A: dnsUpstream.Ip4.As4(), + answers := []dnsmessage.RR{&dnsmessage.A{ + Hdr: dnsmessage.RR_Header{ + Name: dnsmessage.CanonicalName(fqdn), + Rrtype: typ, + Class: dnsmessage.ClassINET, + Ttl: 0, // Must be zero. }, + A: dnsUpstream.Ip4.AsSlice(), }} - if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil { + if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil { return err } } if dnsUpstream.Ip6.IsValid() { typ := dnsmessage.TypeAAAA - answers := []dnsmessage.Resource{{ - Header: dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(fqdn), - Type: typ, - Class: dnsmessage.ClassINET, - TTL: 0, // Must be zero. - }, - Body: &dnsmessage.AAAAResource{ - AAAA: dnsUpstream.Ip6.As16(), + answers := []dnsmessage.RR{&dnsmessage.AAAA{ + Hdr: dnsmessage.RR_Header{ + Name: dnsmessage.CanonicalName(fqdn), + Rrtype: typ, + Class: dnsmessage.ClassINET, + Ttl: 0, // Must be zero. }, + AAAA: dnsUpstream.Ip6.AsSlice(), }} - if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil { + if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil { return err } } diff --git a/control/control_plane_core.go b/control/control_plane_core.go index 8f402c2..b4121c5 100644 --- a/control/control_plane_core.go +++ b/control/control_plane_core.go @@ -20,11 +20,11 @@ import ( "github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common/consts" internal "github.com/daeuniverse/dae/pkg/ebpf_internal" + dnsmessage "github.com/miekg/dns" "github.com/mohae/deepcopy" "github.com/safchain/ethtool" "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" - "golang.org/x/net/dns/dnsmessage" "golang.org/x/sys/unix" ) @@ -629,15 +629,18 @@ func (c *controlPlaneCore) _bindWan(ifname string) error { func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error { // Parse ips from DNS resp answers. var ips []netip.Addr - for _, ans := range cache.Answers { - var ip netip.Addr - switch ans.Header.Type { - case dnsmessage.TypeA: - ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A) - case dnsmessage.TypeAAAA: - ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA) + for _, ans := range cache.Answer { + var ( + ip netip.Addr + ok bool + ) + switch body := ans.(type) { + case *dnsmessage.A: + ip, ok = netip.AddrFromSlice(body.A) + case *dnsmessage.AAAA: + ip, ok = netip.AddrFromSlice(body.AAAA) } - if ip.IsUnspecified() { + if !ok || ip.IsUnspecified() { continue } ips = append(ips, ip) @@ -672,15 +675,18 @@ func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error { func (c *controlPlaneCore) BatchRemoveDomainRouting(cache *DnsCache) error { // Parse ips from DNS resp answers. var ips []netip.Addr - for _, ans := range cache.Answers { - var ip netip.Addr - switch ans.Header.Type { - case dnsmessage.TypeA: - ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A) - case dnsmessage.TypeAAAA: - ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA) + for _, ans := range cache.Answer { + var ( + ip netip.Addr + ok bool + ) + switch body := ans.(type) { + case *dnsmessage.A: + ip, ok = netip.AddrFromSlice(body.A) + case *dnsmessage.AAAA: + ip, ok = netip.AddrFromSlice(body.AAAA) } - if ip.IsUnspecified() { + if !ok || ip.IsUnspecified() { continue } ips = append(ips, ip) diff --git a/control/dns_cache.go b/control/dns_cache.go index 7ddf094..43122ea 100644 --- a/control/dns_cache.go +++ b/control/dns_cache.go @@ -9,49 +9,39 @@ import ( "net/netip" "time" + dnsmessage "github.com/miekg/dns" "github.com/mohae/deepcopy" - "golang.org/x/net/dns/dnsmessage" ) type DnsCache struct { DomainBitmap []uint32 - Answers []dnsmessage.Resource + Answer []dnsmessage.RR Deadline time.Time } -func (c *DnsCache) FillInto(req *dnsmessage.Message) { - req.Answers = deepcopy.Copy(c.Answers).([]dnsmessage.Resource) - // No need to align because of no flipping now. - //// Align question and answer Name. - //if len(req.Questions) > 0 { - // q := req.Questions[0] - // for i := range req.Answers { - // if strings.EqualFold(req.Answers[i].Header.Name.String(), q.Name.String()) { - // req.Answers[i].Header.Name.Data = q.Name.Data - // } - // } - //} - req.RCode = dnsmessage.RCodeSuccess +func (c *DnsCache) FillInto(req *dnsmessage.Msg) { + req.Answer = deepcopy.Copy(c.Answer).([]dnsmessage.RR) + req.Rcode = dnsmessage.RcodeSuccess req.Response = true req.RecursionAvailable = true req.Truncated = false } func (c *DnsCache) IncludeIp(ip netip.Addr) bool { - for _, ans := range c.Answers { - switch body := ans.Body.(type) { - case *dnsmessage.AResource: + for _, ans := range c.Answer { + switch body := ans.(type) { + case *dnsmessage.A: if !ip.Is4() { continue } - if netip.AddrFrom4(body.A) == ip { + if a, ok := netip.AddrFromSlice(body.A); ok && a == ip { return true } - case *dnsmessage.AAAAResource: + case *dnsmessage.AAAA: if !ip.Is6() { continue } - if netip.AddrFrom16(body.AAAA) == ip { + if a, ok := netip.AddrFromSlice(body.AAAA); ok && a == ip { return true } } @@ -60,9 +50,9 @@ func (c *DnsCache) IncludeIp(ip netip.Addr) bool { } func (c *DnsCache) IncludeAnyIp() bool { - for _, ans := range c.Answers { - switch ans.Body.(type) { - case *dnsmessage.AResource, *dnsmessage.AAAAResource: + for _, ans := range c.Answer { + switch ans.(type) { + case *dnsmessage.A, *dnsmessage.AAAA: return true } } diff --git a/control/dns_control.go b/control/dns_control.go index 52e2d4e..849a394 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -8,12 +8,12 @@ package control import ( "context" "encoding/binary" - "errors" "fmt" "io" "math" "net" "net/netip" + "strconv" "strings" "sync" "time" @@ -25,12 +25,12 @@ import ( "github.com/daeuniverse/dae/component/dns" "github.com/daeuniverse/dae/component/outbound" "github.com/daeuniverse/dae/component/outbound/dialer" + dnsmessage "github.com/miekg/dns" "github.com/mohae/deepcopy" "github.com/mzz2017/softwind/netproxy" "github.com/mzz2017/softwind/pkg/fastrand" "github.com/mzz2017/softwind/pool" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) const ( @@ -47,7 +47,6 @@ const ( ) var ( - SuspectedRushAnswerError = fmt.Errorf("suspected DNS rush-answer") UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type") ) @@ -60,7 +59,7 @@ type DnsControllerOption struct { Log *logrus.Logger CacheAccessCallback func(cache *DnsCache) (err error) CacheRemoveCallback func(cache *DnsCache) (err error) - NewCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) + NewCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error) IpVersionPrefer int FixedDomainTtl map[string]int @@ -70,12 +69,12 @@ type DnsController struct { handling sync.Map routing *dns.Dns - qtypePrefer dnsmessage.Type + qtypePrefer uint16 log *logrus.Logger cacheAccessCallback func(cache *DnsCache) (err error) cacheRemoveCallback func(cache *DnsCache) (err error) - newCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) + newCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error) fixedDomainTtl map[string]int @@ -84,7 +83,7 @@ type DnsController struct { dnsCache map[string]*DnsCache } -func parseIpVersionPreference(prefer int) (dnsmessage.Type, error) { +func parseIpVersionPreference(prefer int) (uint16, error) { switch prefer := IpVersionPrefer(prefer); prefer { case IpVersionPrefer_No: return 0, nil @@ -120,15 +119,12 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont }, nil } -func (c *DnsController) cacheKey(qname string, qtype dnsmessage.Type) string { +func (c *DnsController) cacheKey(qname string, qtype uint16) string { // To fqdn. - if !strings.HasSuffix(qname, ".") { - qname = qname + "." - } - return strings.ToLower(qname) + qtype.String() + return dnsmessage.CanonicalName(qname) + strconv.Itoa(int(qtype)) } -func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type) { +func (c *DnsController) RemoveDnsRespCache(qname string, qtype uint16) { c.dnsCacheMu.Lock() key := c.cacheKey(qname, qtype) _, ok := c.dnsCache[key] @@ -137,7 +133,7 @@ func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type) } c.dnsCacheMu.Unlock() } -func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type) (cache *DnsCache) { +func (c *DnsController) LookupDnsRespCache(qname string, qtype uint16) (cache *DnsCache) { c.dnsCacheMu.Lock() cache, ok := c.dnsCache[c.cacheKey(qname, qtype)] c.dnsCacheMu.Unlock() @@ -150,15 +146,15 @@ func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type) } // LookupDnsRespCache_ will modify the msg in place. -func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) { - if len(msg.Questions) == 0 { +func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Msg) (resp []byte) { + if len(msg.Question) == 0 { return nil } - q := msg.Questions[0] + q := msg.Question[0] if msg.Response { return nil } - cache := c.LookupDnsRespCache(q.Name.String(), q.Type) + cache := c.LookupDnsRespCache(q.Name, q.Qtype) if cache != nil { cache.FillInto(msg) b, err := msg.Pack() @@ -176,28 +172,28 @@ func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byt } // DnsRespHandler handle DNS resp. -func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMsg *dnsmessage.Message, err error) { - var msg dnsmessage.Message +func (c *DnsController) DnsRespHandler(data []byte) (newMsg *dnsmessage.Msg, err error) { + var msg dnsmessage.Msg if err = msg.Unpack(data); err != nil { return nil, fmt.Errorf("unpack dns pkt: %w", err) } // Check healthy resp. - if !msg.Response || len(msg.Questions) == 0 { + if !msg.Response || len(msg.Question) == 0 { return &msg, nil } - q := msg.Questions[0] + q := msg.Question[0] // Check suc resp. - if msg.RCode != dnsmessage.RCodeSuccess { + if msg.Rcode != dnsmessage.RcodeSuccess { return &msg, nil } // Get TTL. var ttl uint32 - for i := range msg.Answers { + for i := range msg.Answer { if ttl == 0 { - ttl = msg.Answers[i].Header.TTL + ttl = msg.Answer[i].Header().Ttl break } } @@ -207,7 +203,7 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs } // Check req type. - switch q.Type { + switch q.Qtype { case dnsmessage.TypeA, dnsmessage.TypeAAAA: default: // Update DnsCache. @@ -218,17 +214,17 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs } // Set ttl. - for i := range msg.Answers { + for i := range msg.Answer { // Set TTL = zero. This requests applications must resend every request. // However, it may be not defined in the standard. - msg.Answers[i].Header.TTL = 0 + msg.Answer[i].Header().Ttl = 0 } // Check if request A/AAAA record. var reqIpRecord bool loop: - for i := range msg.Questions { - switch msg.Questions[i].Type { + for i := range msg.Question { + switch msg.Question[i].Qtype { case dnsmessage.TypeA, dnsmessage.TypeAAAA: reqIpRecord = true break loop @@ -242,23 +238,6 @@ loop: return &msg, nil } - if validateRushAns { - exist, e := EnsureAdditionalOpt(&msg, false) - if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) { - c.log.Warnf("EnsureAdditionalOpt: %v", e) - } - if e == nil && !exist { - // Additional record OPT in the request was ensured, and in normal case the resp should also set it. - // This DNS packet may be a rush-answer, and we should reject it. - c.log.WithFields(logrus.Fields{ - "ques": q, - "addition": FormatDnsRsc(msg.Additionals), - "ans": FormatDnsRsc(msg.Answers), - }).Traceln("DNS rush-answer detected") - return nil, SuspectedRushAnswerError - } - } - // Update DnsCache. if err = c.updateDnsCache(&msg, ttl, &q); err != nil { return nil, err @@ -267,31 +246,29 @@ loop: return &msg, nil } -func (c *DnsController) updateDnsCache(msg *dnsmessage.Message, ttl uint32, q *dnsmessage.Question) error { +func (c *DnsController) updateDnsCache(msg *dnsmessage.Msg, ttl uint32, q *dnsmessage.Question) error { // Update DnsCache. if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ - "_qname": q.Name, - "rcode": msg.RCode, - "ans": FormatDnsRsc(msg.Answers), - "auth": FormatDnsRsc(msg.Authorities), - "addition": FormatDnsRsc(msg.Additionals), + "_qname": q.Name, + "rcode": msg.Rcode, + "ans": FormatDnsRsc(msg.Answer), }).Tracef("Update DNS record cache") } - if err := c.UpdateDnsCacheTtl(q.Name.String(), q.Type.String(), msg.Answers, int(ttl)); err != nil { + if err := c.UpdateDnsCacheTtl(q.Name, q.Qtype, msg.Answer, int(ttl)); err != nil { return err } return nil } -func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, answers []dnsmessage.Resource, deadlineFunc func(now time.Time, host string) time.Time) (err error) { +func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadlineFunc func(now time.Time, host string) time.Time) (err error) { var fqdn string if strings.HasSuffix(host, ".") { - fqdn = host + fqdn = strings.ToLower(host) host = host[:len(host)-1] } else { - fqdn = host + "." + fqdn = dnsmessage.CanonicalName(host) } // Bypass pure IP. if _, err = netip.ParseAddr(host); err == nil { @@ -301,11 +278,11 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, ans now := time.Now() deadline := deadlineFunc(now, host) - cacheKey := fqdn + dnsTyp + cacheKey := c.cacheKey(fqdn, dnsTyp) c.dnsCacheMu.Lock() cache, ok := c.dnsCache[cacheKey] if ok { - cache.Answers = answers + cache.Answer = answers cache.Deadline = deadline c.dnsCacheMu.Unlock() } else { @@ -324,7 +301,7 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, ans return nil } -func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp string, answers []dnsmessage.Resource, deadline time.Time) (err error) { +func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadline time.Time) (err error) { return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time { if fixedTtl, ok := c.fixedDomainTtl[host]; ok { /// NOTICE: Cannot set TTL accurately. @@ -336,7 +313,7 @@ func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp string, answe }) } -func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp string, answers []dnsmessage.Resource, ttl int) (err error) { +func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp uint16, answers []dnsmessage.RR, ttl int) (err error) { return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time { if fixedTtl, ok := c.fixedDomainTtl[host]; ok { return now.Add(time.Duration(fixedTtl) * time.Second) @@ -346,27 +323,16 @@ func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp string, answers [] }) } -func (c *DnsController) DnsRespHandlerFactory(validateRushAnsFunc func(from netip.AddrPort) bool) func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) { - return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) { +func (c *DnsController) DnsRespHandlerFactory() func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) { + return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) { // Do not return conn-unrelated err in this func. - validateRushAns := validateRushAnsFunc(from) - msg, err = c.DnsRespHandler(data, validateRushAns) + msg, err = c.DnsRespHandler(data) if err != nil { - if errors.Is(err, SuspectedRushAnswerError) { - if validateRushAns { - // Reject DNS rush-answer. - c.log.WithFields(logrus.Fields{ - "from": from, - }).Tracef("DNS rush-answer rejected") - return nil, nil - } - } else { - if c.log.IsLevelEnabled(logrus.DebugLevel) { - c.log.Debugf("DnsRespHandler: %v", err) - } - return nil, err + if c.log.IsLevelEnabled(logrus.DebugLevel) { + c.log.Debugf("DnsRespHandler: %v", err) } + return nil, err } return msg, nil } @@ -390,11 +356,11 @@ type dialArgument struct { mark uint32 } -func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) { - if c.log.IsLevelEnabled(logrus.TraceLevel) && len(dnsMessage.Questions) > 0 { - q := dnsMessage.Questions[0] +func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (err error) { + if c.log.IsLevelEnabled(logrus.TraceLevel) && len(dnsMessage.Question) > 0 { + q := dnsMessage.Question[0] c.log.Tracef("Received UDP(DNS) %v <-> %v: %v %v", - RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), req.realDst.String(), strings.ToLower(q.Name.String()), q.Type, + RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), req.realDst.String(), strings.ToLower(q.Name), QtypeToString(q.Qtype), ) } @@ -404,10 +370,10 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) // Prepare qname, qtype. var qname string - var qtype dnsmessage.Type - if len(dnsMessage.Questions) != 0 { - qname = dnsMessage.Questions[0].Name.String() - qtype = dnsMessage.Questions[0].Type + var qtype uint16 + if len(dnsMessage.Question) != 0 { + qname = dnsMessage.Question[0].Name + qtype = dnsMessage.Question[0].Qtype } // Check ip version preference and qtype. @@ -421,9 +387,9 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) } // Try to make both A and AAAA lookups. - dnsMessage2 := deepcopy.Copy(dnsMessage).(*dnsmessage.Message) - dnsMessage2.ID = uint16(fastrand.Intn(math.MaxUint16)) - var qtype2 dnsmessage.Type + dnsMessage2 := deepcopy.Copy(dnsMessage).(*dnsmessage.Msg) + dnsMessage2.Id = uint16(fastrand.Intn(math.MaxUint16)) + var qtype2 uint16 switch qtype { case dnsmessage.TypeA: qtype2 = dnsmessage.TypeAAAA @@ -432,7 +398,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) default: return fmt.Errorf("unexpected qtype path") } - dnsMessage2.Questions[0].Type = qtype2 + dnsMessage2.Question[0].Qtype = qtype2 done := make(chan struct{}) go func() { @@ -452,7 +418,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) // resp is not valid. c.log.WithFields(logrus.Fields{ "qname": qname, - }).Tracef("Reject %v due to resp not valid", qtype.String()) + }).Tracef("Reject %v due to resp not valid", qtype) return c.sendReject_(dnsMessage, req) } // resp is valid. @@ -465,25 +431,19 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) } func (c *DnsController) handle_( - dnsMessage *dnsmessage.Message, + dnsMessage *dnsmessage.Msg, req *udpRequest, needResp bool, ) (err error) { // Prepare qname, qtype. var qname string - var qtype dnsmessage.Type - if len(dnsMessage.Questions) != 0 { - q := dnsMessage.Questions[0] - qname = q.Name.String() - qtype = q.Type + var qtype uint16 + if len(dnsMessage.Question) != 0 { + q := dnsMessage.Question[0] + qname = q.Name + qtype = q.Qtype } - //// NOTICE: Rush-answer detector was removed because it does not always work in all districts. - //// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process. - //// Because rush-answer has no resp OPT. We can distinguish them from multiple responses. - //// Note that additional record OPT may not be supported by home router either. - //_, _ = EnsureAdditionalOpt(dnsMessage, true) - // Route request. upstreamIndex, upstream, err := c.routing.RequestSelect(qname, qtype) if err != nil { @@ -509,10 +469,10 @@ func (c *DnsController) handle_( return fmt.Errorf("failed to write cached DNS resp: %w", err) } } - if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 { - q := dnsMessage.Questions[0] + if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Question) > 0 { + q := dnsMessage.Question[0] c.log.Debugf("UDP(DNS) %v <-> Cache: %v %v", - RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type, + RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name), QtypeToString(q.Qtype), ) } return nil @@ -524,7 +484,7 @@ func (c *DnsController) handle_( upstreamName = upstream.String() } c.log.WithFields(logrus.Fields{ - "question": dnsMessage.Questions, + "question": dnsMessage.Question, "upstream": upstreamName, }).Traceln("Request to DNS upstream") } @@ -534,44 +494,44 @@ func (c *DnsController) handle_( if err != nil { return fmt.Errorf("pack DNS packet: %w", err) } - return c.dialSend(0, req, data, dnsMessage.ID, upstream, needResp) + return c.dialSend(0, req, data, dnsMessage.Id, upstream, needResp) } // sendReject_ send empty answer. -func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) { - dnsMessage.Answers = nil - if len(dnsMessage.Questions) > 0 { - q := dnsMessage.Questions[0] - switch typ := q.Type; typ { +func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Msg, req *udpRequest) (err error) { + dnsMessage.Answer = nil + if len(dnsMessage.Question) > 0 { + q := dnsMessage.Question[0] + switch typ := q.Qtype; typ { case dnsmessage.TypeA: - dnsMessage.Answers = []dnsmessage.Resource{{ - Header: dnsmessage.ResourceHeader{ - Name: q.Name, - Type: typ, - Class: dnsmessage.ClassINET, - TTL: 0, + dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.A{ + Hdr: dnsmessage.RR_Header{ + Name: q.Name, + Rrtype: typ, + Class: dnsmessage.ClassINET, + Ttl: 0, }, - Body: &dnsmessage.AResource{A: UnspecifiedAddressA.As4()}, + A: UnspecifiedAddressA.AsSlice(), }} case dnsmessage.TypeAAAA: - dnsMessage.Answers = []dnsmessage.Resource{{ - Header: dnsmessage.ResourceHeader{ - Name: q.Name, - Type: typ, - Class: dnsmessage.ClassINET, - TTL: 0, + dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.AAAA{ + Hdr: dnsmessage.RR_Header{ + Name: q.Name, + Rrtype: typ, + Class: dnsmessage.ClassINET, + Ttl: 0, }, - Body: &dnsmessage.AAAAResource{AAAA: UnspecifiedAddressAAAA.As16()}, + AAAA: UnspecifiedAddressAAAA.AsSlice(), }} } } - dnsMessage.RCode = dnsmessage.RCodeSuccess + dnsMessage.Rcode = dnsmessage.RcodeSuccess dnsMessage.Response = true dnsMessage.RecursionAvailable = true dnsMessage.Truncated = false if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ - "question": dnsMessage.Questions, + "question": dnsMessage.Question, }).Traceln("Reject") } data, err := dnsMessage.Pack() @@ -623,21 +583,9 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte } // dnsRespHandler caches dns response and check rush answers. - dnsRespHandler := c.DnsRespHandlerFactory(func(from netip.AddrPort) bool { - //// NOTICE: Rush-answer detector was removed because it does not always work in all districts. - //// We only validate rush-ans when outbound is direct and pkt does not send to a home device. - //// Because additional record OPT may not be supported by home router. - //// So se should trust home devices even if they make rush-answer (or looks like). - //return dialArgument.bestDialer.Property().Name == "direct" && - // !from.Addr().IsPrivate() && - // !from.Addr().IsLoopback() && - // !from.Addr().IsUnspecified() - - // Do not validate rush-answer. - return false - }) + dnsRespHandler := c.DnsRespHandlerFactory() // Dial and send. - var respMsg *dnsmessage.Message + var respMsg *dnsmessage.Msg // defer in a recursive call will delay Close(), thus we Close() before // the next recursive call. However, a connection cannot be closed twice. // We should set a connClosed flag to avoid it. @@ -774,23 +722,23 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte // Accept. if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ - "question": respMsg.Questions, + "question": respMsg.Question, "upstream": upstreamName, }).Traceln("Accept") } case consts.DnsResponseOutboundIndex_Reject: // Reject the request with empty answer. - respMsg.Answers = nil + respMsg.Answer = nil if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ - "question": respMsg.Questions, + "question": respMsg.Question, "upstream": upstreamName, }).Traceln("Reject with empty answer") } default: if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ - "question": respMsg.Questions, + "question": respMsg.Question, "last_upstream": upstreamName, "next_upstream": nextUpstream.String(), }).Traceln("Change DNS upstream and resend") @@ -798,11 +746,14 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte return c.dialSend(invokingDepth+1, req, data, id, nextUpstream, needResp) } if upstreamIndex.IsReserved() && c.log.IsLevelEnabled(logrus.InfoLevel) { - var qname, qtype string - if len(respMsg.Questions) > 0 { - q := respMsg.Questions[0] - qname = strings.ToLower(q.Name.String()) - qtype = q.Type.String() + var ( + qname string + qtype string + ) + if len(respMsg.Question) > 0 { + q := respMsg.Question[0] + qname = strings.ToLower(q.Name) + qtype = QtypeToString(q.Qtype) } fields := logrus.Fields{ "network": networkType.String(), @@ -825,7 +776,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte } } // Keep the id the same with request. - respMsg.ID = id + respMsg.Id = id data, err = respMsg.Pack() if err != nil { return err diff --git a/control/dns_utils.go b/control/dns_utils.go index a53fbee..3c449b9 100644 --- a/control/dns_utils.go +++ b/control/dns_utils.go @@ -6,98 +6,44 @@ package control import ( - "encoding/binary" "fmt" - "hash/fnv" - "math/rand" - "net/netip" + "strconv" "strings" - "golang.org/x/net/dns/dnsmessage" + dnsmessage "github.com/miekg/dns" ) -// FlipDnsQuestionCase is used to reduce dns pollution. -func FlipDnsQuestionCase(dm *dnsmessage.Message) { - if len(dm.Questions) == 0 { - return - } - q := &dm.Questions[0] - // For reproducibility, we use dm.ID as input and add some entropy to make the results more discrete. - h := fnv.New64() - var buf [4]byte - binary.BigEndian.PutUint16(buf[:], dm.ID) - h.Write(buf[:2]) - binary.BigEndian.PutUint32(buf[:], 20230204) // entropy - h.Write(buf[:]) - r := rand.New(rand.NewSource(int64(h.Sum64()))) - perm := r.Perm(int(q.Name.Length)) - for i := 0; i < int(q.Name.Length/3); i++ { - j := perm[i] - // Upper to lower; lower to upper. - if q.Name.Data[j] >= 'a' && q.Name.Data[j] <= 'z' { - q.Name.Data[j] -= 'a' - 'A' - } else if q.Name.Data[j] >= 'A' && q.Name.Data[j] <= 'Z' { - q.Name.Data[j] += 'a' - 'A' - } - } -} - -// EnsureAdditionalOpt makes sure there is additional record OPT in the request. -func EnsureAdditionalOpt(dm *dnsmessage.Message, isReqAdd bool) (bool, error) { - // Check healthy resp. - if isReqAdd == dm.Response || dm.RCode != dnsmessage.RCodeSuccess || len(dm.Questions) == 0 { - return false, UnsupportedQuestionTypeError - } - q := dm.Questions[0] - switch q.Type { - case dnsmessage.TypeA, dnsmessage.TypeAAAA: - default: - return false, UnsupportedQuestionTypeError - } - - for _, ad := range dm.Additionals { - if ad.Header.Type == dnsmessage.TypeOPT { - // Already has additional record OPT. - return true, nil - } - } - if !isReqAdd { - return false, nil - } - // Add one. - dm.Additionals = append(dm.Additionals, dnsmessage.Resource{ - Header: dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName("."), - Type: dnsmessage.TypeOPT, - Class: 512, TTL: 0, Length: 0, - }, - Body: &dnsmessage.OPTResource{ - Options: nil, - }, - }) - return false, nil -} - type RscWrapper struct { - Rsc dnsmessage.Resource + Rsc dnsmessage.RR } func (w RscWrapper) String() string { var strBody string - switch body := w.Rsc.Body.(type) { - case *dnsmessage.AResource: - strBody = netip.AddrFrom4(body.A).String() - case *dnsmessage.AAAAResource: - strBody = netip.AddrFrom16(body.AAAA).String() + switch body := w.Rsc.(type) { + case *dnsmessage.A: + strBody = body.A.String() + case *dnsmessage.AAAA: + strBody = body.AAAA.String() + case *dnsmessage.CNAME: + strBody = body.Target default: - strBody = body.GoString() + strBody = body.String() } - return fmt.Sprintf("%v(%v): %v", w.Rsc.Header.Name.String(), w.Rsc.Header.Type.String(), strBody) + return fmt.Sprintf("%v(%v): %v", w.Rsc.Header().Name, QtypeToString(w.Rsc.Header().Rrtype), strBody) } -func FormatDnsRsc(ans []dnsmessage.Resource) string { + +func FormatDnsRsc(ans []dnsmessage.RR) string { var w []string for _, a := range ans { w = append(w, RscWrapper{Rsc: a}.String()) } return strings.Join(w, "; ") } + +func QtypeToString(qtype uint16) string { + str, ok := dnsmessage.TypeToString[qtype] + if !ok { + str = strconv.Itoa(int(qtype)) + } + return str +} diff --git a/control/udp.go b/control/udp.go index e4ad09f..eb07037 100644 --- a/control/udp.go +++ b/control/udp.go @@ -20,9 +20,9 @@ import ( "github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/sniffing" internal "github.com/daeuniverse/dae/pkg/ebpf_internal" + dnsmessage "github.com/miekg/dns" "github.com/mzz2017/softwind/pkg/zeroalloc/buffer" "github.com/sirupsen/logrus" - "golang.org/x/net/dns/dnsmessage" ) const ( @@ -31,11 +31,11 @@ const ( MaxRetry = 2 ) -func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Message, timeout time.Duration) { +func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) { if sniffDns { - var dnsmsg dnsmessage.Message + var dnsmsg dnsmessage.Msg if err := dnsmsg.Unpack(data); err == nil { - //log.Printf("DEBUG: lookup %v", dnsmsg.Questions[0].Name) + //log.Printf("DEBUG: lookup %v", dnsmsg.Question[0].Name) return &dnsmsg, DnsNatTimeout } } diff --git a/control/udp_endpoint.go b/control/udp_endpoint.go index 1e37dfc..2e15123 100644 --- a/control/udp_endpoint.go +++ b/control/udp_endpoint.go @@ -6,7 +6,6 @@ package control import ( - "errors" "fmt" "net/netip" "sync" @@ -47,9 +46,6 @@ func (ue *UdpEndpoint) start() { ue.deadlineTimer.Reset(ue.NatTimeout) ue.mu.Unlock() if err = ue.handler(buf[:n], from); err != nil { - if errors.Is(err, SuspectedRushAnswerError) { - continue - } break } } diff --git a/go.mod b/go.mod index 3cdc228..5609a53 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d github.com/gorilla/websocket v1.5.0 github.com/json-iterator/go v1.1.12 + github.com/miekg/dns v1.1.55 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/mzz2017/softwind v0.0.0-20230708102709-26ff44839573 github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd @@ -21,7 +22,6 @@ require ( github.com/x-cray/logrus-prefixed-formatter v0.5.2 golang.org/x/crypto v0.11.0 golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df - golang.org/x/net v0.12.0 golang.org/x/sys v0.10.0 google.golang.org/protobuf v1.31.0 ) @@ -38,6 +38,7 @@ require ( github.com/quic-go/qtls-go1-19 v0.3.2 // indirect github.com/quic-go/qtls-go1-20 v0.2.2 // indirect golang.org/x/mod v0.12.0 // indirect + golang.org/x/net v0.12.0 // indirect golang.org/x/tools v0.11.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130 // indirect ) diff --git a/go.sum b/go.sum index d1b6d1e..eb4ac0f 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= +github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= +github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -175,6 +177,7 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=