From 00b39df677937b00ca5878d8f14ab99fb09c6eac Mon Sep 17 00:00:00 2001 From: mzz <2017@duck.com> Date: Sun, 16 Jul 2023 11:28:28 +0800 Subject: [PATCH] fix(dns): should reject with nx instead of 0.0.0.0 (#141) --- control/control_plane.go | 15 +-- control/dns_cache.go | 7 +- control/dns_control.go | 211 ++++++++++++++++----------------------- 3 files changed, 99 insertions(+), 134 deletions(-) diff --git a/control/control_plane.go b/control/control_plane.go index 332af41..7e362db 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -409,11 +409,12 @@ func NewControlPlane( } return nil }, - NewCache: func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) { + NewCache: func(fqdn string, answers []dnsmessage.RR, deadline time.Time, originalDeadline time.Time) (cache *DnsCache, err error) { return &DnsCache{ - DomainBitmap: plane.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn), - Answer: answers, - Deadline: deadline, + DomainBitmap: plane.routingMatcher.domainMatcher.MatchDomainBitmap(fqdn), + Answer: answers, + Deadline: deadline, + OriginalDeadline: originalDeadline, }, nil }, BestDialerChooser: plane.chooseBestDnsDialer, @@ -423,7 +424,9 @@ func NewControlPlane( return nil, err } // Refresh domain routing cache with new routing. - if dnsCache != nil && len(dnsCache) > 0 { + // FIXME: We temperarily disable it because we want to make change of DNS section take effects immediately. + // TODO: Add change detection. + if false && len(dnsCache) > 0 { for cacheKey, cache := range dnsCache { // Also refresh out-dated routing because kernel map items have no expiration. lastDot := strings.LastIndex(cacheKey, ".") @@ -556,7 +559,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip if !outbound.IsReserved() && domain != "" { switch c.dialMode { case consts.DialMode_Domain: - if cache := c.dnsController.LookupDnsRespCache(domain, common.AddrToDnsType(dst.Addr())); cache != nil { + if cache := c.dnsController.LookupDnsRespCache(c.dnsController.cacheKey(domain, common.AddrToDnsType(dst.Addr())), true); cache != nil { // Has A/AAAA records. It is a real domain. dialMode = consts.DialMode_Domain } else { diff --git a/control/dns_cache.go b/control/dns_cache.go index 43122ea..a88c25d 100644 --- a/control/dns_cache.go +++ b/control/dns_cache.go @@ -14,9 +14,10 @@ import ( ) type DnsCache struct { - DomainBitmap []uint32 - Answer []dnsmessage.RR - Deadline time.Time + DomainBitmap []uint32 + Answer []dnsmessage.RR + Deadline time.Time + OriginalDeadline time.Time // This field is not impacted by `fixed_domain_ttl`. } func (c *DnsCache) FillInto(req *dnsmessage.Msg) { diff --git a/control/dns_control.go b/control/dns_control.go index 6c12927..e3d7d0b 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -59,7 +59,7 @@ type DnsControllerOption struct { Log *logrus.Logger CacheAccessCallback func(cache *DnsCache) (err error) CacheRemoveCallback func(cache *DnsCache) (err error) - NewCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) + NewCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time, originalDeadline time.Time) (cache *DnsCache, err error) BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error) IpVersionPrefer int FixedDomainTtl map[string]int @@ -74,7 +74,7 @@ type DnsController struct { log *logrus.Logger cacheAccessCallback func(cache *DnsCache) (err error) cacheRemoveCallback func(cache *DnsCache) (err error) - newCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time) (cache *DnsCache, err error) + newCache func(fqdn string, answers []dnsmessage.RR, deadline time.Time, originalDeadline time.Time) (cache *DnsCache, err error) bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error) fixedDomainTtl map[string]int @@ -124,37 +124,42 @@ func (c *DnsController) cacheKey(qname string, qtype uint16) string { return dnsmessage.CanonicalName(qname) + strconv.Itoa(int(qtype)) } -func (c *DnsController) RemoveDnsRespCache(qname string, qtype uint16) { +func (c *DnsController) RemoveDnsRespCache(cacheKey string) { c.dnsCacheMu.Lock() - key := c.cacheKey(qname, qtype) - _, ok := c.dnsCache[key] + _, ok := c.dnsCache[cacheKey] if ok { - delete(c.dnsCache, key) + delete(c.dnsCache, cacheKey) } c.dnsCacheMu.Unlock() } -func (c *DnsController) LookupDnsRespCache(qname string, qtype uint16) (cache *DnsCache) { +func (c *DnsController) LookupDnsRespCache(cacheKey string, ignoreFixedTtl bool) (cache *DnsCache) { c.dnsCacheMu.Lock() - cache, ok := c.dnsCache[c.cacheKey(qname, qtype)] + cache, ok := c.dnsCache[cacheKey] c.dnsCacheMu.Unlock() + if !ok { + return nil + } + var deadline time.Time + if !ignoreFixedTtl { + deadline = cache.Deadline + } else { + deadline = cache.OriginalDeadline + } // We should make sure the cache did not expire, or // return nil and request a new lookup to refresh the cache. - if ok && cache.Deadline.After(time.Now()) { - return cache + if !deadline.After(time.Now()) { + return nil } - return nil + if err := c.cacheAccessCallback(cache); err != nil { + c.log.Warnf("failed to BatchUpdateDomainRouting: %v", err) + return nil + } + return cache } // LookupDnsRespCache_ will modify the msg in place. -func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Msg) (resp []byte) { - if len(msg.Question) == 0 { - return nil - } - q := msg.Question[0] - if msg.Response { - return nil - } - cache := c.LookupDnsRespCache(q.Name, q.Qtype) +func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Msg, cacheKey string, ignoreFixedTtl bool) (resp []byte) { + cache := c.LookupDnsRespCache(cacheKey, ignoreFixedTtl) if cache != nil { cache.FillInto(msg) b, err := msg.Pack() @@ -162,31 +167,23 @@ func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Msg) (resp []byte) { c.log.Warnf("failed to pack: %v", err) return nil } - if err = c.cacheAccessCallback(cache); err != nil { - c.log.Warnf("failed to BatchUpdateDomainRouting: %v", err) - return nil - } return b } return nil } -// DnsRespHandler handle DNS resp. -func (c *DnsController) DnsRespHandler(data []byte) (newMsg *dnsmessage.Msg, err error) { - var msg dnsmessage.Msg - if err = msg.Unpack(data); err != nil { - return nil, fmt.Errorf("unpack dns pkt: %w", err) - } +// NormalizeAndCacheDnsResp_ handle DNS resp in place. +func (c *DnsController) NormalizeAndCacheDnsResp_(msg *dnsmessage.Msg) (err error) { // Check healthy resp. if !msg.Response || len(msg.Question) == 0 { - return &msg, nil + return nil } q := msg.Question[0] // Check suc resp. if msg.Rcode != dnsmessage.RcodeSuccess { - return &msg, nil + return nil } // Get TTL. @@ -207,10 +204,10 @@ func (c *DnsController) DnsRespHandler(data []byte) (newMsg *dnsmessage.Msg, err case dnsmessage.TypeA, dnsmessage.TypeAAAA: default: // Update DnsCache. - if err = c.updateDnsCache(&msg, ttl, &q); err != nil { - return nil, err + if err = c.updateDnsCache(msg, ttl, &q); err != nil { + return err } - return &msg, nil + return nil } // Set ttl. @@ -232,18 +229,18 @@ loop: } if !reqIpRecord { // Update DnsCache. - if err = c.updateDnsCache(&msg, ttl, &q); err != nil { - return nil, err + if err = c.updateDnsCache(msg, ttl, &q); err != nil { + return err } - return &msg, nil + return nil } // Update DnsCache. - if err = c.updateDnsCache(&msg, ttl, &q); err != nil { - return nil, err + if err = c.updateDnsCache(msg, ttl, &q); err != nil { + return err } // Pack to get newData. - return &msg, nil + return nil } func (c *DnsController) updateDnsCache(msg *dnsmessage.Msg, ttl uint32, q *dnsmessage.Question) error { @@ -262,7 +259,9 @@ func (c *DnsController) updateDnsCache(msg *dnsmessage.Msg, ttl uint32, q *dnsme return nil } -func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadlineFunc func(now time.Time, host string) time.Time) (err error) { +type daedlineFunc func(now time.Time, host string) (deadline time.Time, originalDeadline time.Time) + +func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadlineFunc daedlineFunc) (err error) { var fqdn string if strings.HasSuffix(host, ".") { fqdn = strings.ToLower(host) @@ -276,7 +275,7 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, ans } now := time.Now() - deadline := deadlineFunc(now, host) + deadline, originalDeadline := deadlineFunc(now, host) cacheKey := c.cacheKey(fqdn, dnsTyp) c.dnsCacheMu.Lock() @@ -284,9 +283,10 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, ans if ok { cache.Answer = answers cache.Deadline = deadline + cache.OriginalDeadline = originalDeadline c.dnsCacheMu.Unlock() } else { - cache, err = c.newCache(fqdn, answers, deadline) + cache, err = c.newCache(fqdn, answers, deadline, originalDeadline) if err != nil { c.dnsCacheMu.Unlock() return err @@ -302,42 +302,29 @@ func (c *DnsController) __updateDnsCacheDeadline(host string, dnsTyp uint16, ans } func (c *DnsController) UpdateDnsCacheDeadline(host string, dnsTyp uint16, answers []dnsmessage.RR, deadline time.Time) (err error) { - return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time { + return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) (daedline time.Time, originalDeadline time.Time) { if fixedTtl, ok := c.fixedDomainTtl[host]; ok { /// NOTICE: Cannot set TTL accurately. if now.Sub(deadline).Seconds() > float64(fixedTtl) { - return now.Add(time.Duration(fixedTtl) * time.Second) + deadline := now.Add(time.Duration(fixedTtl) * time.Second) + return deadline, deadline } } - return deadline + return deadline, deadline }) } func (c *DnsController) UpdateDnsCacheTtl(host string, dnsTyp uint16, answers []dnsmessage.RR, ttl int) (err error) { - return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) time.Time { + return c.__updateDnsCacheDeadline(host, dnsTyp, answers, func(now time.Time, host string) (daedline time.Time, originalDeadline time.Time) { + originalDeadline = now.Add(time.Duration(ttl) * time.Second) if fixedTtl, ok := c.fixedDomainTtl[host]; ok { - return now.Add(time.Duration(fixedTtl) * time.Second) + return now.Add(time.Duration(fixedTtl) * time.Second), originalDeadline } else { - return now.Add(time.Duration(ttl) * time.Second) + return originalDeadline, originalDeadline } }) } -func (c *DnsController) DnsRespHandlerFactory() func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) { - return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Msg, err error) { - // Do not return conn-unrelated err in this func. - - msg, err = c.DnsRespHandler(data) - if err != nil { - if c.log.IsLevelEnabled(logrus.DebugLevel) { - c.log.Debugf("DnsRespHandler: %v", err) - } - return nil, err - } - return msg, nil - } -} - type udpRequest struct { lanWanFlag consts.LanWanFlag realSrc netip.AddrPort @@ -412,8 +399,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (er } // Join results and consider whether to response. - dnsMessage.Response = false - resp := c.LookupDnsRespCache_(dnsMessage) + resp := c.LookupDnsRespCache_(dnsMessage, c.cacheKey(qname, qtype), true) if resp == nil { // resp is not valid. c.log.WithFields(logrus.Fields{ @@ -422,7 +408,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Msg, req *udpRequest) (er return c.sendReject_(dnsMessage, req) } // resp is valid. - cache2 := c.LookupDnsRespCache(qname, qtype2) + cache2 := c.LookupDnsRespCache(c.cacheKey(qname, qtype2), true) if c.qtypePrefer == qtype || cache2 == nil || !cache2.IncludeAnyIp() { return sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag) } else { @@ -450,19 +436,22 @@ func (c *DnsController) handle_( return err } + cacheKey := c.cacheKey(qname, qtype) + if upstreamIndex == consts.DnsRequestOutboundIndex_Reject { // Reject with empty answer. - c.RemoveDnsRespCache(qname, qtype) + c.RemoveDnsRespCache(cacheKey) return c.sendReject_(dnsMessage, req) } // No parallel for the same lookup. - _mu, _ := c.handling.LoadOrStore(c.cacheKey(qname, qtype), new(sync.Mutex)) + _mu, _ := c.handling.LoadOrStore(cacheKey, new(sync.Mutex)) mu := _mu.(*sync.Mutex) mu.Lock() defer mu.Unlock() + defer c.handling.Delete(cacheKey) - if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil { + if resp := c.LookupDnsRespCache_(dnsMessage, cacheKey, false); resp != nil { // Send cache to client directly. if needResp { if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil { @@ -500,31 +489,6 @@ func (c *DnsController) handle_( // sendReject_ send empty answer. func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Msg, req *udpRequest) (err error) { dnsMessage.Answer = nil - if len(dnsMessage.Question) > 0 { - q := dnsMessage.Question[0] - switch typ := q.Qtype; typ { - case dnsmessage.TypeA: - dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.A{ - Hdr: dnsmessage.RR_Header{ - Name: q.Name, - Rrtype: typ, - Class: dnsmessage.ClassINET, - Ttl: 0, - }, - A: UnspecifiedAddressA.AsSlice(), - }} - case dnsmessage.TypeAAAA: - dnsMessage.Answer = []dnsmessage.RR{&dnsmessage.AAAA{ - Hdr: dnsmessage.RR_Header{ - Name: q.Name, - Rrtype: typ, - Class: dnsmessage.ClassINET, - Ttl: 0, - }, - AAAA: UnspecifiedAddressAAAA.AsSlice(), - }} - } - } dnsMessage.Rcode = dnsmessage.RcodeSuccess dnsMessage.Response = true dnsMessage.RecursionAvailable = true @@ -582,8 +546,6 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte IsDns: true, // UDP relies on DNS check result. } - // dnsRespHandler caches dns response and check rush answers. - dnsRespHandler := c.DnsRespHandlerFactory() // Dial and send. var respMsg *dnsmessage.Msg // defer in a recursive call will delay Close(), thus we Close() before @@ -621,7 +583,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(context.TODO(), 5*time.Second) defer cancelDnsReqCtx() go func() { - // Send DNS request at 0, 2, 4 seconds. + // Send DNS request every seconds. for { _, err = conn.Write(data) if err != nil { @@ -641,7 +603,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte select { case <-dnsReqCtx.Done(): return - case <-time.After(2 * time.Second): + case <-time.After(1 * time.Second): } } }() @@ -649,21 +611,18 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte // We can block here because we are in a coroutine. respBuf := pool.GetFullCap(consts.EthernetMtu) defer pool.Put(respBuf) - for { - // Wait for response. - n, err := conn.Read(respBuf) - if err != nil { - return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Property().Name, err) - } - respMsg, err = dnsRespHandler(respBuf[:n], dialArgument.bestTarget) - if err != nil { - return err - } - if respMsg != nil { - cancelDnsReqCtx() - break - } + // Wait for response. + n, err := conn.Read(respBuf) + if err != nil { + return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Property().Name, err) } + var msg dnsmessage.Msg + if err = msg.Unpack(respBuf[:n]); err != nil { + return err + } + respMsg = &msg + cancelDnsReqCtx() + case consts.L4ProtoStr_TCP: // We can block here because we are in a coroutine. @@ -677,7 +636,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte } }() - _ = conn.SetDeadline(time.Now().Add(5 * time.Second)) + _ = conn.SetDeadline(time.Now().Add(4900 * time.Millisecond)) // We should write two byte length in the front of TCP DNS request. bReq := pool.Get(2 + len(data)) defer pool.Put(bReq) @@ -705,13 +664,11 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte if n, err = io.ReadFull(conn, buf[:respLen]); err != nil { return fmt.Errorf("failed to read DNS resp payload: %w", err) } - respMsg, err = dnsRespHandler(buf[:n], dialArgument.bestTarget) - if respMsg == nil && err == nil { - err = fmt.Errorf("bad DNS response") - } - if err != nil { - return fmt.Errorf("failed to write DNS resp to client: %w", err) + var msg dnsmessage.Msg + if err = msg.Unpack(buf[:n]); err != nil { + return err } + respMsg = &msg default: return fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto) } @@ -743,6 +700,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte "upstream": upstreamName, }).Traceln("Reject with empty answer") } + // We also cache response reject. default: if c.log.IsLevelEnabled(logrus.TraceLevel) { c.log.WithFields(logrus.Fields{ @@ -783,13 +741,16 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte return fmt.Errorf("unknown upstream: %v", upstreamIndex.String()) } } - // Keep the id the same with request. - respMsg.Id = id - data, err = respMsg.Pack() - if err != nil { + if err = c.NormalizeAndCacheDnsResp_(respMsg); err != nil { return err } if needResp { + // Keep the id the same with request. + respMsg.Id = id + data, err = respMsg.Pack() + if err != nil { + return err + } if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil { return err }