mirror of
https://github.com/daeuniverse/dae.git
synced 2025-07-09 15:30:06 +07:00
refactor(dns): replace dnsmessage with miekg/dns (#188)
This commit is contained in:
@ -16,8 +16,8 @@ import (
|
||||
"github.com/daeuniverse/dae/common/consts"
|
||||
"github.com/daeuniverse/dae/component/routing"
|
||||
"github.com/daeuniverse/dae/config"
|
||||
dnsmessage "github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
|
||||
@ -148,7 +148,7 @@ func (s *Dns) InitUpstreams() {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
|
||||
func (s *Dns) RequestSelect(qname string, qtype uint16) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
|
||||
// Route.
|
||||
upstreamIndex, err = s.reqMatcher.Match(qname, qtype)
|
||||
if err != nil {
|
||||
@ -170,29 +170,37 @@ func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex
|
||||
return upstreamIndex, upstream, nil
|
||||
}
|
||||
|
||||
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
|
||||
func (s *Dns) ResponseSelect(msg *dnsmessage.Msg, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
|
||||
if !msg.Response {
|
||||
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
|
||||
}
|
||||
|
||||
// Prepare routing.
|
||||
var qname string
|
||||
var qtype dnsmessage.Type
|
||||
var qtype uint16
|
||||
var ips []netip.Addr
|
||||
if len(msg.Questions) == 0 {
|
||||
if len(msg.Question) == 0 {
|
||||
qname = ""
|
||||
qtype = 0
|
||||
} else {
|
||||
q := msg.Questions[0]
|
||||
qname = q.Name.String()
|
||||
qtype = q.Type
|
||||
for _, ans := range msg.Answers {
|
||||
switch body := ans.Body.(type) {
|
||||
case *dnsmessage.AResource:
|
||||
ips = append(ips, netip.AddrFrom4(body.A))
|
||||
case *dnsmessage.AAAAResource:
|
||||
ips = append(ips, netip.AddrFrom16(body.AAAA))
|
||||
q := msg.Question[0]
|
||||
qname = q.Name
|
||||
qtype = q.Qtype
|
||||
for _, ans := range msg.Answer {
|
||||
var (
|
||||
ip netip.Addr
|
||||
ok bool
|
||||
)
|
||||
switch body := ans.(type) {
|
||||
case *dnsmessage.A:
|
||||
ip, ok = netip.AddrFromSlice(body.A)
|
||||
case *dnsmessage.AAAA:
|
||||
ip, ok = netip.AddrFromSlice(body.AAAA)
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12,38 +12,20 @@ import (
|
||||
|
||||
"github.com/daeuniverse/dae/component/routing"
|
||||
"github.com/daeuniverse/dae/pkg/config_parser"
|
||||
dnsmessage "github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
var typeNames = map[string]dnsmessage.Type{
|
||||
"A": dnsmessage.TypeA,
|
||||
"NS": dnsmessage.TypeNS,
|
||||
"CNAME": dnsmessage.TypeCNAME,
|
||||
"SOA": dnsmessage.TypeSOA,
|
||||
"PTR": dnsmessage.TypePTR,
|
||||
"MX": dnsmessage.TypeMX,
|
||||
"TXT": dnsmessage.TypeTXT,
|
||||
"AAAA": dnsmessage.TypeAAAA,
|
||||
"SRV": dnsmessage.TypeSRV,
|
||||
"OPT": dnsmessage.TypeOPT,
|
||||
"WKS": dnsmessage.TypeWKS,
|
||||
"HINFO": dnsmessage.TypeHINFO,
|
||||
"MINFO": dnsmessage.TypeMINFO,
|
||||
"AXFR": dnsmessage.TypeAXFR,
|
||||
"ALL": dnsmessage.TypeALL,
|
||||
}
|
||||
|
||||
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
|
||||
func TypeParserFactory(callback func(f *config_parser.Function, types []uint16, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
|
||||
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
|
||||
var types []dnsmessage.Type
|
||||
var types []uint16
|
||||
for _, v := range paramValueGroup {
|
||||
if t, ok := typeNames[strings.ToUpper(v)]; ok {
|
||||
if t, ok := dnsmessage.StringToType[strings.ToUpper(v)]; ok {
|
||||
types = append(types, t)
|
||||
continue
|
||||
}
|
||||
if val, err := strconv.ParseUint(v, 0, 16); err == nil {
|
||||
types = append(types, dnsmessage.Type(val))
|
||||
types = append(types, uint16(val))
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("unknown DNS request type: %v", v)
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"github.com/daeuniverse/dae/config"
|
||||
"github.com/daeuniverse/dae/pkg/config_parser"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
type RequestMatcherBuilder struct {
|
||||
@ -88,7 +87,7 @@ func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
|
||||
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
@ -166,7 +165,7 @@ type requestMatchSet struct {
|
||||
|
||||
func (m *RequestMatcher) Match(
|
||||
qName string,
|
||||
qType dnsmessage.Type,
|
||||
qType uint16,
|
||||
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
|
||||
var domainMatchBitmap []uint32
|
||||
if qName != "" {
|
||||
@ -185,7 +184,7 @@ func (m *RequestMatcher) Match(
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_QType:
|
||||
if qType == dnsmessage.Type(match.Value) {
|
||||
if qType == match.Value {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Fallback:
|
||||
|
@ -18,7 +18,6 @@ import (
|
||||
"github.com/daeuniverse/dae/pkg/config_parser"
|
||||
"github.com/daeuniverse/dae/pkg/trie"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
type ResponseMatcherBuilder struct {
|
||||
@ -138,7 +137,7 @@ func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values [
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
|
||||
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
|
||||
for i, value := range values {
|
||||
upstreamName := consts.OutboundLogicalOr.String()
|
||||
if i == len(values)-1 {
|
||||
@ -219,7 +218,7 @@ type responseMatchSet struct {
|
||||
|
||||
func (m *ResponseMatcher) Match(
|
||||
qName string,
|
||||
qType dnsmessage.Type,
|
||||
qType uint16,
|
||||
ips []netip.Addr,
|
||||
upstream consts.DnsRequestOutboundIndex,
|
||||
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
|
||||
@ -253,7 +252,7 @@ func (m *ResponseMatcher) Match(
|
||||
}
|
||||
}
|
||||
case consts.MatchType_QType:
|
||||
if qType == dnsmessage.Type(match.Value) {
|
||||
if qType == uint16(match.Value) {
|
||||
goodSubrule = true
|
||||
}
|
||||
case consts.MatchType_Upstream:
|
||||
|
@ -25,11 +25,11 @@ import (
|
||||
|
||||
"github.com/daeuniverse/dae/common/consts"
|
||||
"github.com/daeuniverse/dae/common/netutils"
|
||||
dnsmessage "github.com/miekg/dns"
|
||||
"github.com/mzz2017/softwind/netproxy"
|
||||
"github.com/mzz2017/softwind/pkg/fastrand"
|
||||
"github.com/mzz2017/softwind/protocol/direct"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
)
|
||||
|
||||
type NetworkType struct {
|
||||
|
@ -7,13 +7,14 @@ package socks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/daeuniverse/dae/common/netutils"
|
||||
"github.com/daeuniverse/dae/component/outbound/dialer"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/daeuniverse/dae/common/netutils"
|
||||
"github.com/daeuniverse/dae/component/outbound/dialer"
|
||||
dnsmessage "github.com/miekg/dns"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func TestSocks5(t *testing.T) {
|
||||
|
@ -30,8 +30,8 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
|
||||
log := logger.NewLogger("trace", false)
|
||||
option := &dialer.GlobalOption{
|
||||
Log: log,
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
|
||||
CheckInterval: 15 * time.Second,
|
||||
CheckTolerance: 0,
|
||||
CheckDnsTcp: false,
|
||||
@ -46,7 +46,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
|
||||
FixedIndex: fixedIndex,
|
||||
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
|
||||
for i := 0; i < 10; i++ {
|
||||
d, _, err := g.Select(TestNetworkType)
|
||||
d, _, err := g.Select(TestNetworkType, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -58,7 +58,7 @@ func TestDialerGroup_Select_Fixed(t *testing.T) {
|
||||
fixedIndex = 0
|
||||
g.selectionPolicy.FixedIndex = fixedIndex
|
||||
for i := 0; i < 10; i++ {
|
||||
d, _, err := g.Select(TestNetworkType)
|
||||
d, _, err := g.Select(TestNetworkType, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -73,8 +73,8 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
|
||||
|
||||
option := &dialer.GlobalOption{
|
||||
Log: log,
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
|
||||
CheckInterval: 15 * time.Second,
|
||||
}
|
||||
dialers := []*dialer.Dialer{
|
||||
@ -120,7 +120,7 @@ func TestDialerGroup_Select_MinLastLatency(t *testing.T) {
|
||||
}
|
||||
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(d, alive)
|
||||
}
|
||||
d, _, err := g.Select(TestNetworkType)
|
||||
d, _, err := g.Select(TestNetworkType, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -143,8 +143,8 @@ func TestDialerGroup_Select_Random(t *testing.T) {
|
||||
|
||||
option := &dialer.GlobalOption{
|
||||
Log: log,
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
|
||||
CheckInterval: 15 * time.Second,
|
||||
}
|
||||
dialers := []*dialer.Dialer{
|
||||
@ -159,7 +159,7 @@ func TestDialerGroup_Select_Random(t *testing.T) {
|
||||
}, func(alive bool, networkType *dialer.NetworkType, isInit bool) {})
|
||||
count := make([]int, len(dialers))
|
||||
for i := 0; i < 100; i++ {
|
||||
d, _, err := g.Select(TestNetworkType)
|
||||
d, _, err := g.Select(TestNetworkType, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -183,8 +183,8 @@ func TestDialerGroup_SetAlive(t *testing.T) {
|
||||
|
||||
option := &dialer.GlobalOption{
|
||||
Log: log,
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: testTcpCheckUrl},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: testUdpCheckDns},
|
||||
TcpCheckOptionRaw: dialer.TcpCheckOptionRaw{Raw: []string{testTcpCheckUrl}},
|
||||
CheckDnsOptionRaw: dialer.CheckDnsOptionRaw{Raw: []string{testUdpCheckDns}},
|
||||
CheckInterval: 15 * time.Second,
|
||||
}
|
||||
dialers := []*dialer.Dialer{
|
||||
@ -201,7 +201,7 @@ func TestDialerGroup_SetAlive(t *testing.T) {
|
||||
g.MustGetAliveDialerSet(TestNetworkType).NotifyLatencyChange(dialers[zeroTarget], false)
|
||||
count := make([]int, len(dialers))
|
||||
for i := 0; i < 100; i++ {
|
||||
d, _, err := g.Select(TestNetworkType)
|
||||
d, _, err := g.Select(TestNetworkType, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user