refactor(dns): replace dnsmessage with miekg/dns (#188)

This commit is contained in:
mzz
2023-07-09 16:02:17 +08:00
committed by GitHub
parent b82b31e350
commit 00cf4bc3cd
20 changed files with 327 additions and 427 deletions

View File

@ -16,8 +16,8 @@ import (
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/config"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
@ -148,7 +148,7 @@ func (s *Dns) InitUpstreams() {
wg.Wait()
}
func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
func (s *Dns) RequestSelect(qname string, qtype uint16) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
// Route.
upstreamIndex, err = s.reqMatcher.Match(qname, qtype)
if err != nil {
@ -170,29 +170,37 @@ func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex
return upstreamIndex, upstream, nil
}
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
func (s *Dns) ResponseSelect(msg *dnsmessage.Msg, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
if !msg.Response {
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
}
// Prepare routing.
var qname string
var qtype dnsmessage.Type
var qtype uint16
var ips []netip.Addr
if len(msg.Questions) == 0 {
if len(msg.Question) == 0 {
qname = ""
qtype = 0
} else {
q := msg.Questions[0]
qname = q.Name.String()
qtype = q.Type
for _, ans := range msg.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
ips = append(ips, netip.AddrFrom4(body.A))
case *dnsmessage.AAAAResource:
ips = append(ips, netip.AddrFrom16(body.AAAA))
q := msg.Question[0]
qname = q.Name
qtype = q.Qtype
for _, ans := range msg.Answer {
var (
ip netip.Addr
ok bool
)
switch body := ans.(type) {
case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
}
if !ok {
continue
}
ips = append(ips, ip)
}
}

View File

@ -12,38 +12,20 @@ import (
"github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/pkg/config_parser"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
var typeNames = map[string]dnsmessage.Type{
"A": dnsmessage.TypeA,
"NS": dnsmessage.TypeNS,
"CNAME": dnsmessage.TypeCNAME,
"SOA": dnsmessage.TypeSOA,
"PTR": dnsmessage.TypePTR,
"MX": dnsmessage.TypeMX,
"TXT": dnsmessage.TypeTXT,
"AAAA": dnsmessage.TypeAAAA,
"SRV": dnsmessage.TypeSRV,
"OPT": dnsmessage.TypeOPT,
"WKS": dnsmessage.TypeWKS,
"HINFO": dnsmessage.TypeHINFO,
"MINFO": dnsmessage.TypeMINFO,
"AXFR": dnsmessage.TypeAXFR,
"ALL": dnsmessage.TypeALL,
}
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
func TypeParserFactory(callback func(f *config_parser.Function, types []uint16, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
var types []dnsmessage.Type
var types []uint16
for _, v := range paramValueGroup {
if t, ok := typeNames[strings.ToUpper(v)]; ok {
if t, ok := dnsmessage.StringToType[strings.ToUpper(v)]; ok {
types = append(types, t)
continue
}
if val, err := strconv.ParseUint(v, 0, 16); err == nil {
types = append(types, dnsmessage.Type(val))
types = append(types, uint16(val))
continue
}
return fmt.Errorf("unknown DNS request type: %v", v)

View File

@ -15,7 +15,6 @@ import (
"github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type RequestMatcherBuilder struct {
@ -88,7 +87,7 @@ func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string,
return nil
}
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
@ -166,7 +165,7 @@ type requestMatchSet struct {
func (m *RequestMatcher) Match(
qName string,
qType dnsmessage.Type,
qType uint16,
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
var domainMatchBitmap []uint32
if qName != "" {
@ -185,7 +184,7 @@ func (m *RequestMatcher) Match(
goodSubrule = true
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
if qType == match.Value {
goodSubrule = true
}
case consts.MatchType_Fallback:

View File

@ -18,7 +18,6 @@ import (
"github.com/daeuniverse/dae/pkg/config_parser"
"github.com/daeuniverse/dae/pkg/trie"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type ResponseMatcherBuilder struct {
@ -138,7 +137,7 @@ func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values [
return nil
}
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
@ -219,7 +218,7 @@ type responseMatchSet struct {
func (m *ResponseMatcher) Match(
qName string,
qType dnsmessage.Type,
qType uint16,
ips []netip.Addr,
upstream consts.DnsRequestOutboundIndex,
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
@ -253,7 +252,7 @@ func (m *ResponseMatcher) Match(
}
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
if qType == uint16(match.Value) {
goodSubrule = true
}
case consts.MatchType_Upstream:

View File

@ -25,11 +25,11 @@ import (
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/common/netutils"
dnsmessage "github.com/miekg/dns"
"github.com/mzz2017/softwind/netproxy"
"github.com/mzz2017/softwind/pkg/fastrand"
"github.com/mzz2017/softwind/protocol/direct"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type NetworkType struct {

View File

@ -7,13 +7,14 @@ package socks
import (
"context"
"github.com/daeuniverse/dae/common/netutils"
"github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
"net/netip"
"testing"
"time"
"github.com/daeuniverse/dae/common/netutils"
"github.com/daeuniverse/dae/component/outbound/dialer"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
)
func TestSocks5(t *testing.T) {

View File

@ -30,8 +30,8 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
log := logger.NewLogger("trace", false)
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
CheckTolerance: 0,
CheckDnsTcp: false,
@ -46,7 +46,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
FixedIndex: fixedIndex,
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
for i := 0; i < 10; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
fixedIndex = 0
g.selectionPolicy.FixedIndex = fixedIndex
for i := 0; i < 10; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -73,8 +73,8 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
@ -120,7 +120,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
}
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(d, alive)
}
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -143,8 +143,8 @@ func TestDialerGroup_Select_Random(t *testing.T) {
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
@ -159,7 +159,7 @@ func TestDialerGroup_Select_Random(t *testing.T) {
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}
@ -183,8 +183,8 @@ func TestDialerGroup_SetAlive(t *testing.T) {
option := &dialer.GlobalOption{
Log: log,
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
CheckInterval: 15 * time.Second,
}
dialers := []*dialer.Dialer{
@ -201,7 +201,7 @@ func TestDialerGroup_SetAlive(t *testing.T) {
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(dialers[zeroTarget], false)
count := make([]int, len(dialers))
for i := 0; i < 100; i++ {
d, _, err := g.Select(TestNetworkType)
d, _, err := g.Select(TestNetworkType, false)
if err != nil {
t.Fatal(err)
}