refactor(dns): replace dnsmessage with miekg/dns (#188)

This commit is contained in:
mzz 2023-07-09 16:02:17 +08:00 committed by GitHub
parent b82b31e350
commit 00cf4bc3cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 327 additions and 427 deletions

View File

@ -12,14 +12,13 @@ import (
"io"
"math"
"net/netip"
"strings"
"sync"
"time"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool"
"golang.org/x/net/dns/dnsmessage"
)
var (
@ -90,29 +89,37 @@ func SystemDns() (dns netip.AddrPort, err error) {
return systemDns, nil
}
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (addrs []netip.Addr, err error) {
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (addrs []netip.Addr, err error) {
resources, err := resolve(ctx, d, dns, host, typ, network)
if err != nil {
return nil, err
}
for _, ans := range resources {
if ans.Header.Type != typ {
if ans.Header().Rrtype != typ {
continue
}
var (
ip netip.Addr
okk bool
)
switch typ {
case dnsmessage.TypeA:
a, ok := ans.Body.(*dnsmessage.AResource)
a, ok := ans.(*dnsmessage.A)
if !ok {
return nil, BadDnsAnsError
}
addrs = append(addrs, netip.AddrFrom4(a.A))
ip, okk = netip.AddrFromSlice(a.A)
case dnsmessage.TypeAAAA:
a, ok := ans.Body.(*dnsmessage.AAAAResource)
a, ok := ans.(*dnsmessage.AAAA)
if !ok {
return nil, BadDnsAnsError
}
addrs = append(addrs, netip.AddrFrom16(a.AAAA))
ip, okk = netip.AddrFromSlice(a.AAAA)
}
if !okk {
continue
}
addrs = append(addrs, ip)
}
return addrs, nil
}
@ -124,50 +131,47 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host
return nil, err
}
for _, ans := range resources {
if ans.Header.Type != typ {
if ans.Header().Rrtype != typ {
continue
}
ns, ok := ans.Body.(*dnsmessage.NSResource)
ns, ok := ans.(*dnsmessage.NS)
if !ok {
return nil, BadDnsAnsError
}
records = append(records, ns.NS.String())
records = append(records, ns.Ns)
}
return records, nil
}
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, network string) (ans []dnsmessage.Resource, err error) {
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (ans []dnsmessage.RR, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
fqdn := host
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
fqdn := dnsmessage.CanonicalName(host)
switch typ {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
if addr, err := netip.ParseAddr(host); err == nil {
if (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA {
return []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Class: dnsmessage.ClassINET,
TTL: 0,
Type: typ,
return []dnsmessage.RR{
&dnsmessage.A{
Hdr: dnsmessage.RR_Header{
Name: dnsmessage.CanonicalName(fqdn),
Class: dnsmessage.ClassINET,
Ttl: 0,
Rrtype: typ,
},
Body: &dnsmessage.AResource{A: addr.As4()},
A: addr.AsSlice(),
},
}, nil
} else if addr.Is6() && typ == dnsmessage.TypeAAAA {
return []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Class: dnsmessage.ClassINET,
TTL: 0,
Type: typ,
return []dnsmessage.RR{
&dnsmessage.AAAA{
Hdr: dnsmessage.RR_Header{
Name: dnsmessage.CanonicalName(fqdn),
Class: dnsmessage.ClassINET,
Ttl: 0,
Rrtype: typ,
},
Body: &dnsmessage.AAAAResource{AAAA: addr.As16()},
AAAA: addr.AsSlice(),
},
}, nil
}
@ -177,25 +181,18 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
default:
}
// Build DNS req.
builder := dnsmessage.NewBuilder(nil, dnsmessage.Header{
ID: uint16(fastrand.Intn(math.MaxUint16 + 1)),
Response: false,
OpCode: 0,
Truncated: false,
RecursionDesired: true,
Authoritative: false,
})
if err = builder.StartQuestions(); err != nil {
return nil, err
builder := dnsmessage.Msg{
MsgHdr: dnsmessage.MsgHdr{
Id: uint16(fastrand.Intn(math.MaxUint16 + 1)),
Response: false,
Opcode: 0,
Truncated: false,
RecursionDesired: true,
Authoritative: false,
},
}
if err = builder.Question(dnsmessage.Question{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
}); err != nil {
return nil, err
}
b, err := builder.Finish()
builder.SetQuestion(fqdn, typ)
b, err := builder.Pack()
if err != nil {
return nil, err
}
@ -265,12 +262,12 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
return
}
// Resolve DNS response and extract A/AAAA record.
var msg dnsmessage.Message
var msg dnsmessage.Msg
if err = msg.Unpack(buf[:n]); err != nil {
ch <- err
return
}
ans = msg.Answers
ans = msg.Answer
ch <- nil
}()
select {

View File

@ -12,9 +12,9 @@ import (
"net/netip"
"sync"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/netproxy"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type Ip46 struct {

View File

@ -0,0 +1,22 @@
package netutils
import (
"context"
"net/netip"
"testing"
"time"
"github.com/mzz2017/softwind/protocol/direct"
)
func TestResolveIp46(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ip46, err := ResolveIp46(ctx, direct.SymmetricDirect, netip.MustParseAddrPort("223.5.5.5:53"), "www.apple.com", "udp", false)
if err != nil {
t.Fatal(err)
}
if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() {
t.Fatal("No record")
}
}

View File

@ -12,7 +12,6 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"github.com/mzz2017/softwind/netproxy"
"net/netip"
"net/url"
"path/filepath"
@ -22,9 +21,11 @@ import (
"time"
"unsafe"
"github.com/mzz2017/softwind/netproxy"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/vishvananda/netlink"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix"
)
@ -409,7 +410,7 @@ func NewGcm(key []byte) (cipher.AEAD, error) {
return cipher.NewGCM(block)
}
func AddrToDnsType(addr netip.Addr) dnsmessage.Type {
func AddrToDnsType(addr netip.Addr) uint16 {
if addr.Is4() {
return dnsmessage.TypeA
} else {

View File

@ -16,8 +16,8 @@ import (
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/config"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
@ -148,7 +148,7 @@ func (s *Dns) InitUpstreams() {
wg.Wait()
}
func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
func (s *Dns) RequestSelect(qname string, qtype uint16) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
// Route.
upstreamIndex, err = s.reqMatcher.Match(qname, qtype)
if err != nil {
@ -170,29 +170,37 @@ func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex
return upstreamIndex, upstream, nil
}
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
func (s *Dns) ResponseSelect(msg *dnsmessage.Msg, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
if !msg.Response {
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
}
// Prepare routing.
var qname string
var qtype dnsmessage.Type
var qtype uint16
var ips []netip.Addr
if len(msg.Questions) == 0 {
if len(msg.Question) == 0 {
qname = ""
qtype = 0
} else {
q := msg.Questions[0]
qname = q.Name.String()
qtype = q.Type
for _, ans := range msg.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
ips = append(ips, netip.AddrFrom4(body.A))
case *dnsmessage.AAAAResource:
ips = append(ips, netip.AddrFrom16(body.AAAA))
q := msg.Question[0]
qname = q.Name
qtype = q.Qtype
for _, ans := range msg.Answer {
var (
ip netip.Addr
ok bool
)
switch body := ans.(type) {
case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
}
if !ok {
continue
}
ips = append(ips, ip)
}
}

View File

@ -12,38 +12,20 @@ import (
"github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/pkg/config_parser"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
var typeNames = map[string]dnsmessage.Type{
"A": dnsmessage.TypeA,
"NS": dnsmessage.TypeNS,
"CNAME": dnsmessage.TypeCNAME,
"SOA": dnsmessage.TypeSOA,
"PTR": dnsmessage.TypePTR,
"MX": dnsmessage.TypeMX,
"TXT": dnsmessage.TypeTXT,
"AAAA": dnsmessage.TypeAAAA,
"SRV": dnsmessage.TypeSRV,
"OPT": dnsmessage.TypeOPT,
"WKS": dnsmessage.TypeWKS,
"HINFO": dnsmessage.TypeHINFO,
"MINFO": dnsmessage.TypeMINFO,
"AXFR": dnsmessage.TypeAXFR,
"ALL": dnsmessage.TypeALL,
}
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
func TypeParserFactory(callback func(f *config_parser.Function, types []uint16, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
var types []dnsmessage.Type
var types []uint16
for _, v := range paramValueGroup {
if t, ok := typeNames[strings.ToUpper(v)]; ok {
if t, ok := dnsmessage.StringToType[strings.ToUpper(v)]; ok {
types = append(types, t)
continue
}
if val, err := strconv.ParseUint(v, 0, 16); err == nil {
types = append(types, dnsmessage.Type(val))
types = append(types, uint16(val))
continue
}
return fmt.Errorf("unknown DNS request type: %v", v)

View File

@ -15,7 +15,6 @@ import (
"github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type RequestMatcherBuilder struct {
@ -88,7 +87,7 @@ func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string,
return nil
}
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
@ -166,7 +165,7 @@ type requestMatchSet struct {
func (m *RequestMatcher) Match(
qName string,
qType dnsmessage.Type,
qType uint16,
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
var domainMatchBitmap []uint32
if qName != "" {
@ -185,7 +184,7 @@ func (m *RequestMatcher) Match(
goodSubrule = true
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
if qType == match.Value {
goodSubrule = true
}
case consts.MatchType_Fallback:

View File

@ -18,7 +18,6 @@ import (
"github.com/daeuniverse/dae/pkg/config_parser"
"github.com/daeuniverse/dae/pkg/trie"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type ResponseMatcherBuilder struct {
@ -138,7 +137,7 @@ func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values [
return nil
}
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
@ -219,7 +218,7 @@ type responseMatchSet struct {
func (m *ResponseMatcher) Match(
qName string,
qType dnsmessage.Type,
qType uint16,
ips []netip.Addr,
upstream consts.DnsRequestOutboundIndex,
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
@ -253,7 +252,7 @@ func (m *ResponseMatcher) Match(
}
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
if qType == uint16(match.Value) {
goodSubrule = true
}
case consts.MatchType_Upstream:

View File

@ -25,11 +25,11 @@ import (
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/common/netutils"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/protocol/direct"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type NetworkType struct {

View File

@ -7,13 +7,14 @@ package socks
import (
"context"
"github.com/daeuniverse/dae/common/netutils"
"github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"testing"
"time"
"github.com/daeuniverse/dae/common/netutils"
"github.com/daeuniverse/dae/component/outbound/dialer"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
func TestSocks5(t *testing.T) {

View File

@ -30,8 +30,8 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
log := logger.NewLogger("trace", false)
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
CheckTolerance: 0,
CheckDnsTcp: false,
@ -46,7 +46,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
FixedIndex: fixedIndex,
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
for i := 0; i < 10; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
fixedIndex = 0
g.selectionPolicy.FixedIndex = fixedIndex
for i := 0; i < 10; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -73,8 +73,8 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
@ -120,7 +120,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
}
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(d, alive)
}
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -143,8 +143,8 @@ func TestDialerGroup_Select_Random(t *testing.T) {
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
@ -159,7 +159,7 @@ func TestDialerGroup_Select_Random(t *testing.T) {
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -183,8 +183,8 @@ func TestDialerGroup_SetAlive(t *testing.T) {
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
@ -201,7 +201,7 @@ func TestDialerGroup_SetAlive(t *testing.T) {
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(dialers[zeroTarget], false)
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}

View File

@ -32,12 +32,12 @@ import (
"github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy"
"github.com/mzz2017/softwind/pool"
"github.com/mzz2017/softwind/protocol/direct"
"github.com/mzz2017/softwind/transport/grpc"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix"
)
@ -409,10 +409,10 @@ func NewControlPlane(
}
return nil
},
NewCache: func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error) {
NewCache: func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) {
return &DnsCache{
DomainBitmap: plane.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn),
Answers: answers,
Answer: answers,
Deadline: deadline,
}, nil
},
@ -433,8 +433,13 @@ func NewControlPlane(
continue
}
host := cacheKey[:lastDot]
typ := cacheKey[lastDot+1:]
_ = plane.dnsController.UpdateDnsCacheDeadline(host, typ, cache.Answers, cache.Deadline)
_typ := cacheKey[lastDot+1:]
typ, err := strconv.ParseUint(_typ, 10, 16)
if err != nil {
// Unexpected.
return nil, err
}
_ = plane.dnsController.UpdateDnsCacheDeadline(host, uint16(typ), cache.Answer, cache.Deadline)
}
} else if _bpf != nil {
// Is reloading, and dnsCache == nil.
@ -509,43 +514,36 @@ func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err
/// Updates dns cache to support domain routing for hostname of dns_upstream.
// Ten years later.
deadline := time.Now().Add(time.Hour * 24 * 365 * 10)
fqdn := dnsUpstream.Hostname
if !strings.HasSuffix(fqdn, ".") {
fqdn = fqdn + "."
}
fqdn := dnsmessage.CanonicalName(dnsUpstream.Hostname)
if dnsUpstream.Ip4.IsValid() {
typ := dnsmessage.TypeA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AResource{
A: dnsUpstream.Ip4.As4(),
answers := []dnsmessage.RR{&dnsmessage.A{
Hdr: dnsmessage.RR_Header{
Name: dnsmessage.CanonicalName(fqdn),
Rrtype: typ,
Class: dnsmessage.ClassINET,
Ttl: 0, // Must be zero.
},
A: dnsUpstream.Ip4.AsSlice(),
}}
if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil {
if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
return err
}
}
if dnsUpstream.Ip6.IsValid() {
typ := dnsmessage.TypeAAAA
answers := []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0, // Must be zero.
},
Body: &dnsmessage.AAAAResource{
AAAA: dnsUpstream.Ip6.As16(),
answers := []dnsmessage.RR{&dnsmessage.AAAA{
Hdr: dnsmessage.RR_Header{
Name: dnsmessage.CanonicalName(fqdn),
Rrtype: typ,
Class: dnsmessage.ClassINET,
Ttl: 0, // Must be zero.
},
AAAA: dnsUpstream.Ip6.AsSlice(),
}}
if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil {
if err = c.dnsController.UpdateDnsCacheDeadline(dnsUpstream.Hostname, typ, answers, deadline); err != nil {
return err
}
}

View File

@ -20,11 +20,11 @@ import (
"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy"
"github.com/safchain/ethtool"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix"
)
@ -629,15 +629,18 @@ func (c *controlPlaneCore) _bindWan(ifname string) error {
func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error {
// Parse ips from DNS resp answers.
var ips []netip.Addr
for _, ans := range cache.Answers {
var ip netip.Addr
switch ans.Header.Type {
case dnsmessage.TypeA:
ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A)
case dnsmessage.TypeAAAA:
ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA)
for _, ans := range cache.Answer {
var (
ip netip.Addr
ok bool
)
switch body := ans.(type) {
case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
}
if ip.IsUnspecified() {
if !ok || ip.IsUnspecified() {
continue
}
ips = append(ips, ip)
@ -672,15 +675,18 @@ func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error {
func (c *controlPlaneCore) BatchRemoveDomainRouting(cache *DnsCache) error {
// Parse ips from DNS resp answers.
var ips []netip.Addr
for _, ans := range cache.Answers {
var ip netip.Addr
switch ans.Header.Type {
case dnsmessage.TypeA:
ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A)
case dnsmessage.TypeAAAA:
ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA)
for _, ans := range cache.Answer {
var (
ip netip.Addr
ok bool
)
switch body := ans.(type) {
case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
}
if ip.IsUnspecified() {
if !ok || ip.IsUnspecified() {
continue
}
ips = append(ips, ip)

View File

@ -9,49 +9,39 @@ import (
"net/netip"
"time"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy"
"golang.org/x/net/dns/dnsmessage"
)
type DnsCache struct {
DomainBitmap []uint32
Answers []dnsmessage.Resource
Answer []dnsmessage.RR
Deadline time.Time
}
func (c *DnsCache) FillInto(req *dnsmessage.Message) {
req.Answers = deepcopy.Copy(c.Answers).([]dnsmessage.Resource)
// No need to align because of no flipping now.
//// Align question and answer Name.
//if len(req.Questions) > 0 {
// q := req.Questions[0]
// for i := range req.Answers {
// if strings.EqualFold(req.Answers[i].Header.Name.String(), q.Name.String()) {
// req.Answers[i].Header.Name.Data = q.Name.Data
// }
// }
//}
req.RCode = dnsmessage.RCodeSuccess
func (c *DnsCache) FillInto(req *dnsmessage.Msg) {
req.Answer = deepcopy.Copy(c.Answer).([]dnsmessage.RR)
req.Rcode = dnsmessage.RcodeSuccess
req.Response = true
req.RecursionAvailable = true
req.Truncated = false
}
func (c *DnsCache) IncludeIp(ip netip.Addr) bool {
for _, ans := range c.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
for _, ans := range c.Answer {
switch body := ans.(type) {
case *dnsmessage.A:
if !ip.Is4() {
continue
}
if netip.AddrFrom4(body.A) == ip {
if a, ok := netip.AddrFromSlice(body.A); ok && a == ip {
return true
}
case *dnsmessage.AAAAResource:
case *dnsmessage.AAAA:
if !ip.Is6() {
continue
}
if netip.AddrFrom16(body.AAAA) == ip {
if a, ok := netip.AddrFromSlice(body.AAAA); ok && a == ip {
return true
}
}
@ -60,9 +50,9 @@ func (c *DnsCache) IncludeIp(ip netip.Addr) bool {
}
func (c *DnsCache) IncludeAnyIp() bool {
for _, ans := range c.Answers {
switch ans.Body.(type) {
case *dnsmessage.AResource, *dnsmessage.AAAAResource:
for _, ans := range c.Answer {
switch ans.(type) {
case *dnsmessage.A, *dnsmessage.AAAA:
return true
}
}

View File

@ -8,12 +8,12 @@ package control
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
@ -25,12 +25,12 @@ import (
"github.com/daeuniverse/dae/component/dns"
"github.com/daeuniverse/dae/component/outbound"
"github.com/daeuniverse/dae/component/outbound/dialer"
dnsmessage "github.com/miekg/dns"
"github.com/mohae/deepcopy"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/pool"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
const (
@ -47,7 +47,6 @@ const (
)
var (
SuspectedRushAnswerError = fmt.Errorf("suspected DNS rush-answer")
UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type")
)
@ -60,7 +59,7 @@ type DnsControllerOption struct {
Log *logrus.Logger
CacheAccessCallback func(cache *DnsCache) (err error)
CacheRemoveCallback func(cache *DnsCache) (err error)
NewCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
NewCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error)
BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
IpVersionPrefer int
FixedDomainTtl map[string]int
@ -70,12 +69,12 @@ type DnsController struct {
handling sync.Map
routing *dns.Dns
qtypePrefer dnsmessage.Type
qtypePrefer uint16
log *logrus.Logger
cacheAccessCallback func(cache *DnsCache) (err error)
cacheRemoveCallback func(cache *DnsCache) (err error)
newCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
newCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error)
bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
fixedDomainTtl map[string]int
@ -84,7 +83,7 @@ type DnsController struct {
dnsCache map[string]*DnsCache
}
func parseIpVersionPreference(prefer int) (dnsmessage.Type, error) {
func parseIpVersionPreference(prefer int) (uint16, error) {
switch prefer := IpVersionPrefer(prefer); prefer {
case IpVersionPrefer_No:
return 0, nil
@ -120,15 +119,12 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont
}, nil
}
func (c *DnsController) cacheKey(qname string, qtype dnsmessage.Type) string {
func (c *DnsController) cacheKey(qname string, qtype uint16) string {
// To fqdn.
if !strings.HasSuffix(qname, ".") {
qname = qname + "."
}
return strings.ToLower(qname) + qtype.String()
return dnsmessage.CanonicalName(qname) + strconv.Itoa(int(qtype))
}
func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type) {
func (c *DnsController) RemoveDnsRespCache(qname string, qtype uint16) {
c.dnsCacheMu.Lock()
key := c.cacheKey(qname, qtype)
_, ok := c.dnsCache[key]
@ -137,7 +133,7 @@ func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type)
}
c.dnsCacheMu.Unlock()
}
func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type) (cache *DnsCache) {
func (c *DnsController) LookupDnsRespCache(qname string, qtype uint16) (cache *DnsCache) {
c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[c.cacheKey(qname, qtype)]
c.dnsCacheMu.Unlock()
@ -150,15 +146,15 @@ func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type)
}
// LookupDnsRespCache_ will modify the msg in place.
func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) {
if len(msg.Questions) == 0 {
func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Msg) (resp []byte) {
if len(msg.Question) == 0 {
return nil
}
q := msg.Questions[0]
q := msg.Question[0]
if msg.Response {
return nil
}
cache := c.LookupDnsRespCache(q.Name.String(), q.Type)
cache := c.LookupDnsRespCache(q.Name, q.Qtype)
if cache != nil {
cache.FillInto(msg)
b, err := msg.Pack()
@ -176,28 +172,28 @@ func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byt
}
// DnsRespHandler handle DNS resp.
func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMsg *dnsmessage.Message, err error) {
var msg dnsmessage.Message
func (c *DnsController) DnsRespHandler(data []byte) (newMsg *dnsmessage.Msg, err error) {
var msg dnsmessage.Msg
if err = msg.Unpack(data); err != nil {
return nil, fmt.Errorf("unpack dns pkt: %w", err)
}
// Check healthy resp.
if !msg.Response || len(msg.Questions) == 0 {
if !msg.Response || len(msg.Question) == 0 {
return &msg, nil
}
q := msg.Questions[0]
q := msg.Question[0]
// Check suc resp.
if msg.RCode != dnsmessage.RCodeSuccess {
if msg.Rcode != dnsmessage.RcodeSuccess {
return &msg, nil
}
// Get TTL.
var ttl uint32
for i := range msg.Answers {
for i := range msg.Answer {
if ttl == 0 {
ttl = msg.Answers[i].Header.TTL
ttl = msg.Answer[i].Header().Ttl
break
}
}
@ -207,7 +203,7 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs
}
// Check req type.
switch q.Type {
switch q.Qtype {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
// Update DnsCache.
@ -218,17 +214,17 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs
}
// Set ttl.
for i := range msg.Answers {
for i := range msg.Answer {
// Set TTL = zero. This requests applications must resend every request.
// However, it may be not defined in the standard.
msg.Answers[i].Header.TTL = 0
msg.Answer[i].Header().Ttl = 0
}
// Check if request A/AAAA record.
var reqIpRecord bool
loop:
for i := range msg.Questions {
switch msg.Questions[i].Type {
for i := range msg.Question {
switch msg.Question[i].Qtype {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
reqIpRecord = true
break loop
@ -242,23 +238,6 @@ loop:
return &msg, nil
}
if validateRushAns {
exist, e := EnsureAdditionalOpt(&msg, false)
if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) {
c.log.Warnf("EnsureAdditionalOpt: %v", e)
}
if e == nil && !exist {
// Additional record OPT in the request was ensured, and in normal case the resp should also set it.
// This DNS packet may be a rush-answer, and we should reject it.
c.log.WithFields(logrus.Fields{
"ques": q,
"addition": FormatDnsRsc(msg.Additionals),
"ans": FormatDnsRsc(msg.Answers),
}).Traceln("DNS rush-answer detected")
return nil, SuspectedRushAnswerError
}
}
// Update DnsCache.
if err = c.updateDnsCache(&msg, ttl, &q); err != nil {
return nil, err
@ -267,31 +246,29 @@ loop:
return &msg, nil
}
func (c *DnsController) updateDnsCache(msg *dnsmessage.Message, ttl uint32, q *dnsmessage.Question) error {
func (c *DnsController) updateDnsCache(msg *dnsmessage.Msg, ttl uint32, q *dnsmessage.Question) error {
// Update DnsCache.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"_qname": q.Name,
"rcode": msg.RCode,
"ans": FormatDnsRsc(msg.Answers),
"auth": FormatDnsRsc(msg.Authorities),
"addition": FormatDnsRsc(msg.Additionals),
"_qname": q.Name,
"rcode": msg.Rcode,
"ans": FormatDnsRsc(msg.Answer),
}).Tracef("Update DNS record cache")
}
if err := c.UpdateDnsCacheTtl(q.Name.String(), q.Type.String(), msg.Answers, int(ttl)); err != nil {
if err := c.UpdateDnsCacheTtl(q.Name, q.Qtype, msg.Answer, int(ttl)); err != nil {
return err
}
return nil
}
func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, answers []dnsmessage.Resource, deadlineFunc func(now time.Time, host string) time.Time) (err error) {
func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadlineFunc func(now time.Time, host string) time.Time) (err error) {
var fqdn string
if strings.HasSuffix(host, ".") {
fqdn = host
fqdn = strings.ToLower(host)
host = host[:len(host)-1]
} else {
fqdn = host + "."
fqdn = dnsmessage.CanonicalName(host)
}
// Bypass pure IP.
if _, err = netip.ParseAddr(host); err == nil {
@ -301,11 +278,11 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, ans
now := time.Now()
deadline := deadlineFunc(now, host)
cacheKey := fqdn + dnsTyp
cacheKey := c.cacheKey(fqdn, dnsTyp)
c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[cacheKey]
if ok {
cache.Answers = answers
cache.Answer = answers
cache.Deadline = deadline
c.dnsCacheMu.Unlock()
} else {
@ -324,7 +301,7 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp string, ans
return nil
}
func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp string, answers []dnsmessage.Resource, deadline time.Time) (err error) {
func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadline time.Time) (err error) {
return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time {
if fixedTtl, ok := c.fixedDomainTtl[host]; ok {
/// NOTICE: Cannot set TTL accurately.
@ -336,7 +313,7 @@ func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp string, answe
})
}
func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp string, answers []dnsmessage.Resource, ttl int) (err error) {
func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp uint16, answers []dnsmessage.RR, ttl int) (err error) {
return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time {
if fixedTtl, ok := c.fixedDomainTtl[host]; ok {
return now.Add(time.Duration(fixedTtl) * time.Second)
@ -346,27 +323,16 @@ func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp string, answers []
})
}
func (c *DnsController) DnsRespHandlerFactory(validateRushAnsFunc func(from netip.AddrPort) bool) func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
func (c *DnsController) DnsRespHandlerFactory() func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) {
return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) {
// Do not return conn-unrelated err in this func.
validateRushAns := validateRushAnsFunc(from)
msg, err = c.DnsRespHandler(data, validateRushAns)
msg, err = c.DnsRespHandler(data)
if err != nil {
if errors.Is(err, SuspectedRushAnswerError) {
if validateRushAns {
// Reject DNS rush-answer.
c.log.WithFields(logrus.Fields{
"from": from,
}).Tracef("DNS rush-answer rejected")
return nil, nil
}
} else {
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.Debugf("DnsRespHandler: %v", err)
}
return nil, err
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.Debugf("DnsRespHandler: %v", err)
}
return nil, err
}
return msg, nil
}
@ -390,11 +356,11 @@ type dialArgument struct {
mark uint32
}
func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) {
if c.log.IsLevelEnabled(logrus.TraceLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (err error) {
if c.log.IsLevelEnabled(logrus.TraceLevel) && len(dnsMessage.Question) > 0 {
q := dnsMessage.Question[0]
c.log.Tracef("Received UDP(DNS) %v <-> %v: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), req.realDst.String(), strings.ToLower(q.Name.String()), q.Type,
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), req.realDst.String(), strings.ToLower(q.Name), QtypeToString(q.Qtype),
)
}
@ -404,10 +370,10 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
// Prepare qname, qtype.
var qname string
var qtype dnsmessage.Type
if len(dnsMessage.Questions) != 0 {
qname = dnsMessage.Questions[0].Name.String()
qtype = dnsMessage.Questions[0].Type
var qtype uint16
if len(dnsMessage.Question) != 0 {
qname = dnsMessage.Question[0].Name
qtype = dnsMessage.Question[0].Qtype
}
// Check ip version preference and qtype.
@ -421,9 +387,9 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
}
// Try to make both A and AAAA lookups.
dnsMessage2 := deepcopy.Copy(dnsMessage).(*dnsmessage.Message)
dnsMessage2.ID = uint16(fastrand.Intn(math.MaxUint16))
var qtype2 dnsmessage.Type
dnsMessage2 := deepcopy.Copy(dnsMessage).(*dnsmessage.Msg)
dnsMessage2.Id = uint16(fastrand.Intn(math.MaxUint16))
var qtype2 uint16
switch qtype {
case dnsmessage.TypeA:
qtype2 = dnsmessage.TypeAAAA
@ -432,7 +398,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
default:
return fmt.Errorf("unexpected qtype path")
}
dnsMessage2.Questions[0].Type = qtype2
dnsMessage2.Question[0].Qtype = qtype2
done := make(chan struct{})
go func() {
@ -452,7 +418,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
// resp is not valid.
c.log.WithFields(logrus.Fields{
"qname": qname,
}).Tracef("Reject %v due to resp not valid", qtype.String())
}).Tracef("Reject %v due to resp not valid", qtype)
return c.sendReject_(dnsMessage, req)
}
// resp is valid.
@ -465,25 +431,19 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
}
func (c *DnsController) handle_(
dnsMessage *dnsmessage.Message,
dnsMessage *dnsmessage.Msg,
req *udpRequest,
needResp bool,
) (err error) {
// Prepare qname, qtype.
var qname string
var qtype dnsmessage.Type
if len(dnsMessage.Questions) != 0 {
q := dnsMessage.Questions[0]
qname = q.Name.String()
qtype = q.Type
var qtype uint16
if len(dnsMessage.Question) != 0 {
q := dnsMessage.Question[0]
qname = q.Name
qtype = q.Qtype
}
//// NOTICE: Rush-answer detector was removed because it does not always work in all districts.
//// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
//// Because rush-answer has no resp OPT. We can distinguish them from multiple responses.
//// Note that additional record OPT may not be supported by home router either.
//_, _ = EnsureAdditionalOpt(dnsMessage, true)
// Route request.
upstreamIndex, upstream, err := c.routing.RequestSelect(qname, qtype)
if err != nil {
@ -509,10 +469,10 @@ func (c *DnsController) handle_(
return fmt.Errorf("failed to write cached DNS resp: %w", err)
}
}
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Question) > 0 {
q := dnsMessage.Question[0]
c.log.Debugf("UDP(DNS) %v <-> Cache: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type,
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name), QtypeToString(q.Qtype),
)
}
return nil
@ -524,7 +484,7 @@ func (c *DnsController) handle_(
upstreamName = upstream.String()
}
c.log.WithFields(logrus.Fields{
"question": dnsMessage.Questions,
"question": dnsMessage.Question,
"upstream": upstreamName,
}).Traceln("Request to DNS upstream")
}
@ -534,44 +494,44 @@ func (c *DnsController) handle_(
if err != nil {
return fmt.Errorf("pack DNS packet: %w", err)
}
return c.dialSend(0, req, data, dnsMessage.ID, upstream, needResp)
return c.dialSend(0, req, data, dnsMessage.Id, upstream, needResp)
}
// sendReject_ send empty answer.
func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) {
dnsMessage.Answers = nil
if len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
switch typ := q.Type; typ {
func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Msg, req *udpRequest) (err error) {
dnsMessage.Answer = nil
if len(dnsMessage.Question) > 0 {
q := dnsMessage.Question[0]
switch typ := q.Qtype; typ {
case dnsmessage.TypeA:
dnsMessage.Answers = []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: q.Name,
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0,
dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.A{
Hdr: dnsmessage.RR_Header{
Name: q.Name,
Rrtype: typ,
Class: dnsmessage.ClassINET,
Ttl: 0,
},
Body: &dnsmessage.AResource{A: UnspecifiedAddressA.As4()},
A: UnspecifiedAddressA.AsSlice(),
}}
case dnsmessage.TypeAAAA:
dnsMessage.Answers = []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: q.Name,
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0,
dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.AAAA{
Hdr: dnsmessage.RR_Header{
Name: q.Name,
Rrtype: typ,
Class: dnsmessage.ClassINET,
Ttl: 0,
},
Body: &dnsmessage.AAAAResource{AAAA: UnspecifiedAddressAAAA.As16()},
AAAA: UnspecifiedAddressAAAA.AsSlice(),
}}
}
}
dnsMessage.RCode = dnsmessage.RCodeSuccess
dnsMessage.Rcode = dnsmessage.RcodeSuccess
dnsMessage.Response = true
dnsMessage.RecursionAvailable = true
dnsMessage.Truncated = false
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": dnsMessage.Questions,
"question": dnsMessage.Question,
}).Traceln("Reject")
}
data, err := dnsMessage.Pack()
@ -623,21 +583,9 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
}
// dnsRespHandler caches dns response and check rush answers.
dnsRespHandler := c.DnsRespHandlerFactory(func(from netip.AddrPort) bool {
//// NOTICE: Rush-answer detector was removed because it does not always work in all districts.
//// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
//// Because additional record OPT may not be supported by home router.
//// So se should trust home devices even if they make rush-answer (or looks like).
//return dialArgument.bestDialer.Property().Name == "direct" &&
// !from.Addr().IsPrivate() &&
// !from.Addr().IsLoopback() &&
// !from.Addr().IsUnspecified()
// Do not validate rush-answer.
return false
})
dnsRespHandler := c.DnsRespHandlerFactory()
// Dial and send.
var respMsg *dnsmessage.Message
var respMsg *dnsmessage.Msg
// defer in a recursive call will delay Close(), thus we Close() before
// the next recursive call. However, a connection cannot be closed twice.
// We should set a connClosed flag to avoid it.
@ -774,23 +722,23 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
// Accept.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": respMsg.Questions,
"question": respMsg.Question,
"upstream": upstreamName,
}).Traceln("Accept")
}
case consts.DnsResponseOutboundIndex_Reject:
// Reject the request with empty answer.
respMsg.Answers = nil
respMsg.Answer = nil
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": respMsg.Questions,
"question": respMsg.Question,
"upstream": upstreamName,
}).Traceln("Reject with empty answer")
}
default:
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": respMsg.Questions,
"question": respMsg.Question,
"last_upstream": upstreamName,
"next_upstream": nextUpstream.String(),
}).Traceln("Change DNS upstream and resend")
@ -798,11 +746,14 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
return c.dialSend(invokingDepth+1, req, data, id, nextUpstream, needResp)
}
if upstreamIndex.IsReserved() && c.log.IsLevelEnabled(logrus.InfoLevel) {
var qname, qtype string
if len(respMsg.Questions) > 0 {
q := respMsg.Questions[0]
qname = strings.ToLower(q.Name.String())
qtype = q.Type.String()
var (
qname string
qtype string
)
if len(respMsg.Question) > 0 {
q := respMsg.Question[0]
qname = strings.ToLower(q.Name)
qtype = QtypeToString(q.Qtype)
}
fields := logrus.Fields{
"network": networkType.String(),
@ -825,7 +776,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
}
}
// Keep the id the same with request.
respMsg.ID = id
respMsg.Id = id
data, err = respMsg.Pack()
if err != nil {
return err

View File

@ -6,98 +6,44 @@
package control
import (
"encoding/binary"
"fmt"
"hash/fnv"
"math/rand"
"net/netip"
"strconv"
"strings"
"golang.org/x/net/dns/dnsmessage"
dnsmessage "github.com/miekg/dns"
)
// FlipDnsQuestionCase is used to reduce dns pollution.
func FlipDnsQuestionCase(dm *dnsmessage.Message) {
if len(dm.Questions) == 0 {
return
}
q := &dm.Questions[0]
// For reproducibility, we use dm.ID as input and add some entropy to make the results more discrete.
h := fnv.New64()
var buf [4]byte
binary.BigEndian.PutUint16(buf[:], dm.ID)
h.Write(buf[:2])
binary.BigEndian.PutUint32(buf[:], 20230204) // entropy
h.Write(buf[:])
r := rand.New(rand.NewSource(int64(h.Sum64())))
perm := r.Perm(int(q.Name.Length))
for i := 0; i < int(q.Name.Length/3); i++ {
j := perm[i]
// Upper to lower; lower to upper.
if q.Name.Data[j] >= 'a' && q.Name.Data[j] <= 'z' {
q.Name.Data[j] -= 'a' - 'A'
} else if q.Name.Data[j] >= 'A' && q.Name.Data[j] <= 'Z' {
q.Name.Data[j] += 'a' - 'A'
}
}
}
// EnsureAdditionalOpt makes sure there is additional record OPT in the request.
func EnsureAdditionalOpt(dm *dnsmessage.Message, isReqAdd bool) (bool, error) {
// Check healthy resp.
if isReqAdd == dm.Response || dm.RCode != dnsmessage.RCodeSuccess || len(dm.Questions) == 0 {
return false, UnsupportedQuestionTypeError
}
q := dm.Questions[0]
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
return false, UnsupportedQuestionTypeError
}
for _, ad := range dm.Additionals {
if ad.Header.Type == dnsmessage.TypeOPT {
// Already has additional record OPT.
return true, nil
}
}
if !isReqAdd {
return false, nil
}
// Add one.
dm.Additionals = append(dm.Additionals, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: dnsmessage.MustNewName("."),
Type: dnsmessage.TypeOPT,
Class: 512, TTL: 0, Length: 0,
},
Body: &dnsmessage.OPTResource{
Options: nil,
},
})
return false, nil
}
type RscWrapper struct {
Rsc dnsmessage.Resource
Rsc dnsmessage.RR
}
func (w RscWrapper) String() string {
var strBody string
switch body := w.Rsc.Body.(type) {
case *dnsmessage.AResource:
strBody = netip.AddrFrom4(body.A).String()
case *dnsmessage.AAAAResource:
strBody = netip.AddrFrom16(body.AAAA).String()
switch body := w.Rsc.(type) {
case *dnsmessage.A:
strBody = body.A.String()
case *dnsmessage.AAAA:
strBody = body.AAAA.String()
case *dnsmessage.CNAME:
strBody = body.Target
default:
strBody = body.GoString()
strBody = body.String()
}
return fmt.Sprintf("%v(%v): %v", w.Rsc.Header.Name.String(), w.Rsc.Header.Type.String(), strBody)
return fmt.Sprintf("%v(%v): %v", w.Rsc.Header().Name, QtypeToString(w.Rsc.Header().Rrtype), strBody)
}
func FormatDnsRsc(ans []dnsmessage.Resource) string {
func FormatDnsRsc(ans []dnsmessage.RR) string {
var w []string
for _, a := range ans {
w = append(w, RscWrapper{Rsc: a}.String())
}
return strings.Join(w, "; ")
}
func QtypeToString(qtype uint16) string {
str, ok := dnsmessage.TypeToString[qtype]
if !ok {
str = strconv.Itoa(int(qtype))
}
return str
}

View File

@ -20,9 +20,9 @@ import (
"github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/daeuniverse/dae/component/sniffing"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
const (
@ -31,11 +31,11 @@ const (
MaxRetry = 2
)
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Message, timeout time.Duration) {
func ChooseNatTimeout(data []byte, sniffDns bool) (dmsg *dnsmessage.Msg, timeout time.Duration) {
if sniffDns {
var dnsmsg dnsmessage.Message
var dnsmsg dnsmessage.Msg
if err := dnsmsg.Unpack(data); err == nil {
//log.Printf("DEBUG: lookup %v", dnsmsg.Questions[0].Name)
//log.Printf("DEBUG: lookup %v", dnsmsg.Question[0].Name)
return &dnsmsg, DnsNatTimeout
}
}

View File

@ -6,7 +6,6 @@
package control
import (
"errors"
"fmt"
"net/netip"
"sync"
@ -47,9 +46,6 @@ func (ue *UdpEndpoint) start() {
ue.deadlineTimer.Reset(ue.NatTimeout)
ue.mu.Unlock()
if err = ue.handler(buf[:n], from); err != nil {
if errors.Is(err, SuspectedRushAnswerError) {
continue
}
break
}
}

3
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d
github.com/gorilla/websocket v1.5.0
github.com/json-iterator/go v1.1.12
github.com/miekg/dns v1.1.55
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/mzz2017/softwind v0.0.0-20230708102709-26ff44839573
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
@ -21,7 +22,6 @@ require (
github.com/x-cray/logrus-prefixed-formatter v0.5.2
golang.org/x/crypto v0.11.0
golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df
golang.org/x/net v0.12.0
golang.org/x/sys v0.10.0
google.golang.org/protobuf v1.31.0
)
@ -38,6 +38,7 @@ require (
github.com/quic-go/qtls-go1-19 v0.3.2 // indirect
github.com/quic-go/qtls-go1-20 v0.2.2 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.12.0 // indirect
golang.org/x/tools v0.11.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230706204954-ccb25ca9f130 // indirect
)

3
go.sum
View File

@ -78,6 +78,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI=
github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE=
github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -175,6 +177,7 @@ golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=