refactor(dns): replace dnsmessage with miekg/dns (#188)

This commit is contained in:
mzz
2023-07-09 16:02:17 +08:00
committed by GitHub
parent b82b31e350
commit 00cf4bc3cd
20 changed files with 327 additions and 427 deletions

View File

@ -16,8 +16,8 @@ import (
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/config"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
var BadUpstreamFormatError = fmt.Errorf("bad upstream format")
@ -148,7 +148,7 @@ func (s *Dns) InitUpstreams() {
wg.Wait()
}
func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
func (s *Dns) RequestSelect(qname string, qtype uint16) (upstreamIndex consts.DnsRequestOutboundIndex, upstream *Upstream, err error) {
// Route.
upstreamIndex, err = s.reqMatcher.Match(qname, qtype)
if err != nil {
@ -170,29 +170,37 @@ func (s *Dns) RequestSelect(qname string, qtype dnsmessage.Type) (upstreamIndex
return upstreamIndex, upstream, nil
}
func (s *Dns) ResponseSelect(msg *dnsmessage.Message, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
func (s *Dns) ResponseSelect(msg *dnsmessage.Msg, fromUpstream *Upstream) (upstreamIndex consts.DnsResponseOutboundIndex, upstream *Upstream, err error) {
if !msg.Response {
return 0, nil, fmt.Errorf("DNS response expected but DNS request received")
}
// Prepare routing.
var qname string
var qtype dnsmessage.Type
var qtype uint16
var ips []netip.Addr
if len(msg.Questions) == 0 {
if len(msg.Question) == 0 {
qname = ""
qtype = 0
} else {
q := msg.Questions[0]
qname = q.Name.String()
qtype = q.Type
for _, ans := range msg.Answers {
switch body := ans.Body.(type) {
case *dnsmessage.AResource:
ips = append(ips, netip.AddrFrom4(body.A))
case *dnsmessage.AAAAResource:
ips = append(ips, netip.AddrFrom16(body.AAAA))
q := msg.Question[0]
qname = q.Name
qtype = q.Qtype
for _, ans := range msg.Answer {
var (
ip netip.Addr
ok bool
)
switch body := ans.(type) {
case *dnsmessage.A:
ip, ok = netip.AddrFromSlice(body.A)
case *dnsmessage.AAAA:
ip, ok = netip.AddrFromSlice(body.AAAA)
}
if !ok {
continue
}
ips = append(ips, ip)
}
}

View File

@ -12,38 +12,20 @@ import (
"github.com/daeuniverse/dae/component/routing"
"github.com/daeuniverse/dae/pkg/config_parser"
dnsmessage "github.com/miekg/dns"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
var typeNames = map[string]dnsmessage.Type{
"A": dnsmessage.TypeA,
"NS": dnsmessage.TypeNS,
"CNAME": dnsmessage.TypeCNAME,
"SOA": dnsmessage.TypeSOA,
"PTR": dnsmessage.TypePTR,
"MX": dnsmessage.TypeMX,
"TXT": dnsmessage.TypeTXT,
"AAAA": dnsmessage.TypeAAAA,
"SRV": dnsmessage.TypeSRV,
"OPT": dnsmessage.TypeOPT,
"WKS": dnsmessage.TypeWKS,
"HINFO": dnsmessage.TypeHINFO,
"MINFO": dnsmessage.TypeMINFO,
"AXFR": dnsmessage.TypeAXFR,
"ALL": dnsmessage.TypeALL,
}
func TypeParserFactory(callback func(f *config_parser.Function, types []dnsmessage.Type, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
func TypeParserFactory(callback func(f *config_parser.Function, types []uint16, overrideOutbound *routing.Outbound) (err error)) routing.FunctionParser {
return func(log *logrus.Logger, f *config_parser.Function, key string, paramValueGroup []string, overrideOutbound *routing.Outbound) (err error) {
var types []dnsmessage.Type
var types []uint16
for _, v := range paramValueGroup {
if t, ok := typeNames[strings.ToUpper(v)]; ok {
if t, ok := dnsmessage.StringToType[strings.ToUpper(v)]; ok {
types = append(types, t)
continue
}
if val, err := strconv.ParseUint(v, 0, 16); err == nil {
types = append(types, dnsmessage.Type(val))
types = append(types, uint16(val))
continue
}
return fmt.Errorf("unknown DNS request type: %v", v)

View File

@ -15,7 +15,6 @@ import (
"github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type RequestMatcherBuilder struct {
@ -88,7 +87,7 @@ func (b *RequestMatcherBuilder) addQName(f *config_parser.Function, key string,
return nil
}
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
func (b *RequestMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
@ -166,7 +165,7 @@ type requestMatchSet struct {
func (m *RequestMatcher) Match(
qName string,
qType dnsmessage.Type,
qType uint16,
) (upstreamIndex consts.DnsRequestOutboundIndex, err error) {
var domainMatchBitmap []uint32
if qName != "" {
@ -185,7 +184,7 @@ func (m *RequestMatcher) Match(
goodSubrule = true
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
if qType == match.Value {
goodSubrule = true
}
case consts.MatchType_Fallback:

View File

@ -18,7 +18,6 @@ import (
"github.com/daeuniverse/dae/pkg/config_parser"
"github.com/daeuniverse/dae/pkg/trie"
"github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage"
)
type ResponseMatcherBuilder struct {
@ -138,7 +137,7 @@ func (b *ResponseMatcherBuilder) addUpstream(f *config_parser.Function, values [
return nil
}
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []dnsmessage.Type, upstream *routing.Outbound) (err error) {
func (b *ResponseMatcherBuilder) addQType(f *config_parser.Function, values []uint16, upstream *routing.Outbound) (err error) {
for i, value := range values {
upstreamName := consts.OutboundLogicalOr.String()
if i == len(values)-1 {
@ -219,7 +218,7 @@ type responseMatchSet struct {
func (m *ResponseMatcher) Match(
qName string,
qType dnsmessage.Type,
qType uint16,
ips []netip.Addr,
upstream consts.DnsRequestOutboundIndex,
) (upstreamIndex consts.DnsResponseOutboundIndex, err error) {
@ -253,7 +252,7 @@ func (m *ResponseMatcher) Match(
}
}
case consts.MatchType_QType:
if qType == dnsmessage.Type(match.Value) {
if qType == uint16(match.Value) {
goodSubrule = true
}
case consts.MatchType_Upstream: