diff --git a/common/netutils/dns.go b/common/netutils/dns.go index 6d12b0b..f863d5a 100644 --- a/common/netutils/dns.go +++ b/common/netutils/dns.go @@ -142,6 +142,10 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (ans []dnsmessage.Resource, err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() + fqdn := host + if !strings.HasSuffix(fqdn, ".") { + fqdn += "." + } switch typ { case dnsmessage.TypeA, dnsmessage.TypeAAAA: if addr, err := netip.ParseAddr(host); err == nil { @@ -149,7 +153,10 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st return []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ - Type: typ, + Name: dnsmessage.MustNewName(fqdn), + Class: dnsmessage.ClassINET, + TTL: 0, + Type: typ, }, Body: &dnsmessage.AResource{A: addr.As4()}, }, @@ -158,7 +165,10 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st return []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ - Type: typ, + Name: dnsmessage.MustNewName(fqdn), + Class: dnsmessage.ClassINET, + TTL: 0, + Type: typ, }, Body: &dnsmessage.AAAAResource{AAAA: addr.As16()}, }, @@ -181,10 +191,6 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st if err = builder.StartQuestions(); err != nil { return nil, err } - fqdn := host - if !strings.HasSuffix(fqdn, ".") { - fqdn += "." - } if err = builder.Question(dnsmessage.Question{ Name: dnsmessage.MustNewName(fqdn), Type: typ, diff --git a/control/control_plane_core.go b/control/control_plane_core.go index 6435d11..e5b6e2c 100644 --- a/control/control_plane_core.go +++ b/control/control_plane_core.go @@ -523,12 +523,17 @@ func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error { // Parse ips from DNS resp answers. var ips []netip.Addr for _, ans := range cache.Answers { + var ip netip.Addr switch ans.Header.Type { case dnsmessage.TypeA: - ips = append(ips, netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A)) + ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A) case dnsmessage.TypeAAAA: - ips = append(ips, netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA)) + ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA) } + if ip.IsUnspecified() { + continue + } + ips = append(ips, ip) } if len(ips) == 0 { return nil diff --git a/control/dns_cache.go b/control/dns_cache.go index 8660dfd..7ddf094 100644 --- a/control/dns_cache.go +++ b/control/dns_cache.go @@ -58,3 +58,13 @@ func (c *DnsCache) IncludeIp(ip netip.Addr) bool { } return false } + +func (c *DnsCache) IncludeAnyIp() bool { + for _, ans := range c.Answers { + switch ans.Body.(type) { + case *dnsmessage.AResource, *dnsmessage.AAAAResource: + return true + } + } + return false +} diff --git a/control/dns_control.go b/control/dns_control.go index 2ace736..96515b1 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -33,7 +33,8 @@ import ( const ( MaxDnsLookupDepth = 3 - minFirefoxCacheTimeout = 120 * time.Second + minFirefoxCacheTtl = 120 + minFirefoxCacheTimeout = minFirefoxCacheTtl * time.Second ) type IpVersionPrefer int @@ -49,6 +50,11 @@ var ( UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type") ) +var ( + UnspecifiedAddressA = netip.MustParseAddr("0.0.0.0") + UnspecifiedAddressAAAA = netip.MustParseAddr("::") +) + type DnsControllerOption struct { Log *logrus.Logger CacheAccessCallback func(cache *DnsCache) (err error) @@ -125,14 +131,12 @@ func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type) c.dnsCacheMu.Unlock() } func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type) (cache *DnsCache) { - now := time.Now() - c.dnsCacheMu.Lock() cache, ok := c.dnsCache[c.cacheKey(qname, qtype)] c.dnsCacheMu.Unlock() // We should make sure the remaining TTL is greater than 120s (minFirefoxCacheTimeout), or // return nil and request a new lookup to refresh the cache. - if ok && cache.Deadline.After(now.Add(minFirefoxCacheTimeout)) { + if ok { return cache } return nil @@ -187,35 +191,52 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs return &msg, nil } - // Check req type. - switch q.Type { - case dnsmessage.TypeA, dnsmessage.TypeAAAA: - default: - return &msg, nil - } - - // Set ttl. + // Get TTL. var ttl uint32 for i := range msg.Answers { if ttl == 0 { ttl = msg.Answers[i].Header.TTL + break } + } + if ttl == 0 { + // It seems no answers (NXDomain). + ttl = minFirefoxCacheTtl + } + + // Check req type. + switch q.Type { + case dnsmessage.TypeA, dnsmessage.TypeAAAA: + default: + // Update DnsCache. + if err = c.updateDnsCache(&msg, ttl, &q); err != nil { + return nil, err + } + return &msg, nil + } + + // Set ttl. + for i := range msg.Answers { // Set TTL = zero. This requests applications must resend every request. // However, it may be not defined in the standard. msg.Answers[i].Header.TTL = 0 } - // Check if there is any A/AAAA record. - var hasIpRecord bool + // Check if request A/AAAA record. + var reqIpRecord bool loop: - for i := range msg.Answers { - switch msg.Answers[i].Header.Type { + for i := range msg.Questions { + switch msg.Questions[i].Type { case dnsmessage.TypeA, dnsmessage.TypeAAAA: - hasIpRecord = true + reqIpRecord = true break loop } } - if !hasIpRecord { + if !reqIpRecord { + // Update DnsCache. + if err = c.updateDnsCache(&msg, ttl, &q); err != nil { + return nil, err + } return &msg, nil } @@ -236,6 +257,15 @@ loop: } } + // Update DnsCache. + if err = c.updateDnsCache(&msg, ttl, &q); err != nil { + return nil, err + } + // Pack to get newData. + return &msg, nil +} + +func (c *DnsController) updateDnsCache(msg *dnsmessage.Message, ttl uint32, q *dnsmessage.Question) error { // Update DnsCache. if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ @@ -252,11 +282,10 @@ loop: } cacheTimeout += 5 * time.Second // DNS lookup timeout. - if err = c.UpdateDnsCache(q.Name.String(), q.Type.String(), msg.Answers, time.Now().Add(cacheTimeout)); err != nil { - return nil, err + if err := c.UpdateDnsCache(q.Name.String(), q.Type.String(), msg.Answers, time.Now().Add(cacheTimeout)); err != nil { + return err } - // Pack to get newData. - return &msg, nil + return nil } func (c *DnsController) UpdateDnsCache(host string, dnsTyp string, answers []dnsmessage.Resource, deadline time.Time) (err error) { @@ -407,7 +436,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) } // resp is valid. cache2 := c.LookupDnsRespCache(qname, qtype2) - if c.qtypePrefer == qtype || cache2 == nil { + if c.qtypePrefer == qtype || cache2 == nil || !cache2.IncludeAnyIp() { return sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag) } else { return c.sendReject_(dnsMessage, req) @@ -490,6 +519,31 @@ func (c *DnsController) handle_( // sendReject_ send empty answer. func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) { dnsMessage.Answers = nil + if len(dnsMessage.Questions) > 0 { + q := dnsMessage.Questions[0] + switch typ := q.Type; typ { + case dnsmessage.TypeA: + dnsMessage.Answers = []dnsmessage.Resource{{ + Header: dnsmessage.ResourceHeader{ + Name: q.Name, + Type: typ, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + Body: &dnsmessage.AResource{A: UnspecifiedAddressA.As4()}, + }} + case dnsmessage.TypeAAAA: + dnsMessage.Answers = []dnsmessage.Resource{{ + Header: dnsmessage.ResourceHeader{ + Name: q.Name, + Type: typ, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + Body: &dnsmessage.AAAAResource{AAAA: UnspecifiedAddressAAAA.As16()}, + }} + } + } dnsMessage.RCode = dnsmessage.RCodeSuccess dnsMessage.Response = true dnsMessage.RecursionAvailable = true @@ -497,7 +551,7 @@ func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Message, req *udpRequ if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ "question": dnsMessage.Questions, - }).Traceln("Reject with empty answer") + }).Traceln("Reject") } data, err := dnsMessage.Pack() if err != nil {