optimize: refine dns cache behaviour

This commit is contained in:
mzz2017
2023-03-17 13:13:42 +08:00
parent 936d9a773a
commit fe3f9c62c3
3 changed files with 54 additions and 14 deletions

View File

@ -71,7 +71,7 @@ var (
func Run(log *logrus.Logger, conf *config.Config) (err error) { func Run(log *logrus.Logger, conf *config.Config) (err error) {
// New ControlPlane. // New ControlPlane.
c, err := newControlPlane(log, nil, conf) c, err := newControlPlane(log, nil, nil, conf)
if err != nil { if err != nil {
return err return err
} }
@ -157,14 +157,15 @@ loop:
// New control plane. // New control plane.
obj := c.EjectBpf() obj := c.EjectBpf()
dnsCache := c.CloneDnsCache()
log.Warnln("[Reload] Load new control plane") log.Warnln("[Reload] Load new control plane")
newC, err := newControlPlane(log, obj, newConf) newC, err := newControlPlane(log, obj, dnsCache, newConf)
if err != nil { if err != nil {
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"err": err, "err": err,
}).Errorln("[Reload] Failed to reload; try to roll back configuration") }).Errorln("[Reload] Failed to reload; try to roll back configuration")
// Load last config back. // Load last config back.
newC, err = newControlPlane(log, obj, conf) newC, err = newControlPlane(log, obj, dnsCache, conf)
if err != nil { if err != nil {
sdnotify.Stopping() sdnotify.Stopping()
obj.Close() obj.Close()
@ -201,7 +202,7 @@ loop:
return nil return nil
} }
func newControlPlane(log *logrus.Logger, bpf interface{}, conf *config.Config) (c *control.ControlPlane, err error) { func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*control.DnsCache, conf *config.Config) (c *control.ControlPlane, err error) {
// Deep copy to prevent modification. // Deep copy to prevent modification.
conf = deepcopy.Copy(conf).(*config.Config) conf = deepcopy.Copy(conf).(*config.Config)
@ -256,6 +257,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, conf *config.Config) (
c, err = control.NewControlPlane( c, err = control.NewControlPlane(
log, log,
bpf, bpf,
dnsCache,
tagToNodeList, tagToNodeList,
conf.Group, conf.Group,
&conf.Routing, &conf.Routing,

View File

@ -21,6 +21,7 @@ import (
"github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal" internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
"github.com/mohae/deepcopy"
"github.com/mzz2017/softwind/pool" "github.com/mzz2017/softwind/pool"
"github.com/mzz2017/softwind/protocol/direct" "github.com/mzz2017/softwind/protocol/direct"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -64,6 +65,7 @@ type ControlPlane struct {
func NewControlPlane( func NewControlPlane(
log *logrus.Logger, log *logrus.Logger,
_bpf interface{}, _bpf interface{},
dnsCache map[string]*DnsCache,
tagToNodeList map[string][]string, tagToNodeList map[string][]string,
groups []config.Group, groups []config.Group,
routingA *config.Routing, routingA *config.Routing,
@ -114,8 +116,10 @@ func NewControlPlane(
} }
/// Load pre-compiled programs and maps into the kernel. /// Load pre-compiled programs and maps into the kernel.
log.Infof("Loading eBPF programs and maps into the kernel.") if _bpf == nil {
log.Infof("The loading process takes about 150MB free memory, which will be released after loading. Insufficient memory will cause loading failure.") log.Infof("Loading eBPF programs and maps into the kernel.")
log.Infof("The loading process takes about 150MB free memory, which will be released after loading. Insufficient memory will cause loading failure.")
}
//var bpf bpfObjects //var bpf bpfObjects
var ProgramOptions = ebpf.ProgramOptions{ var ProgramOptions = ebpf.ProgramOptions{
KernelTypes: nil, KernelTypes: nil,
@ -300,6 +304,7 @@ func NewControlPlane(
if err != nil { if err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err) return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err)
} }
/// Dial mode. /// Dial mode.
dialMode, err := consts.ParseDialMode(global.DialMode) dialMode, err := consts.ParseDialMode(global.DialMode)
if err != nil { if err != nil {
@ -356,6 +361,24 @@ func NewControlPlane(
}); err != nil { }); err != nil {
return nil, err return nil, err
} }
// Refresh domain routing cache with new routing.
if dnsCache != nil {
for cacheKey, cache := range dnsCache {
if time.Now().After(cache.Deadline) {
continue
}
lastDot := strings.LastIndex(cacheKey, ".")
if lastDot == -1 || lastDot == len(cacheKey)-1 {
// Not a valid key.
log.Warnln("Invalid cache key:", cacheKey)
continue
}
host := cacheKey[:lastDot]
typ := cacheKey[lastDot+1:]
_ = plane.dnsController.UpdateDnsCache(host, typ, cache.Answers, cache.Deadline)
}
}
// Init immediately to avoid DNS leaking in the very beginning because param control_plane_dns_routing will // Init immediately to avoid DNS leaking in the very beginning because param control_plane_dns_routing will
// be set in callback. // be set in callback.
go dnsUpstream.InitUpstreams() go dnsUpstream.InitUpstreams()
@ -372,6 +395,12 @@ func (c *ControlPlane) InjectBpf(bpf *bpfObjects) {
c.core.InjectBpf(bpf) c.core.InjectBpf(bpf)
} }
func (c *ControlPlane) CloneDnsCache() map[string]*DnsCache {
c.dnsController.dnsCacheMu.Lock()
defer c.dnsController.dnsCacheMu.Unlock()
return deepcopy.Copy(c.dnsController.dnsCache).(map[string]*DnsCache)
}
func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err error) { func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err error) {
// Waiting for ready. // Waiting for ready.
select { select {
@ -413,7 +442,7 @@ func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err
A: dnsUpstream.Ip4.As4(), A: dnsUpstream.Ip4.As4(),
}, },
}} }}
if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil { if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil {
return err return err
} }
} }
@ -431,7 +460,7 @@ func (c *ControlPlane) dnsUpstreamReadyCallback(dnsUpstream *dns.Upstream) (err
AAAA: dnsUpstream.Ip6.As16(), AAAA: dnsUpstream.Ip6.As16(),
}, },
}} }}
if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ, answers, deadline); err != nil { if err = c.dnsController.UpdateDnsCache(dnsUpstream.Hostname, typ.String(), answers, deadline); err != nil {
return err return err
} }
} }
@ -448,7 +477,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
// Has A/AAAA records. It is a real domain. // Has A/AAAA records. It is a real domain.
dialMode = consts.DialMode_Domain dialMode = consts.DialMode_Domain
} else { } else {
// Check if the domain is in real domain set (bloom filter). // Check if the domain is in real-domain set (bloom filter).
c.muRealDomainSet.RLock() c.muRealDomainSet.RLock()
if c.realDomainSet.TestString(domain) { if c.realDomainSet.TestString(domain) {
c.muRealDomainSet.RUnlock() c.muRealDomainSet.RUnlock()

View File

@ -28,7 +28,8 @@ import (
) )
const ( const (
MaxDnsLookupDepth = 3 MaxDnsLookupDepth = 3
minFirefoxCacheTimeout = 120 * time.Second
) )
var ( var (
@ -80,7 +81,9 @@ func (c *DnsController) LookupDnsRespCache(domain string, t dnsmessage.Type) (ca
c.dnsCacheMu.Lock() c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[strings.ToLower(domain)+t.String()] cache, ok := c.dnsCache[strings.ToLower(domain)+t.String()]
c.dnsCacheMu.Unlock() c.dnsCacheMu.Unlock()
if ok && cache.Deadline.After(now) { // We should make sure the remaining TTL is greater than 120s (minFirefoxCacheTimeout), or
// return nil and request a new lookup to refresh the cache.
if ok && cache.Deadline.After(now.Add(minFirefoxCacheTimeout)) {
return cache return cache
} }
return nil return nil
@ -194,14 +197,20 @@ loop:
"addition": FormatDnsRsc(msg.Additionals), "addition": FormatDnsRsc(msg.Additionals),
}).Tracef("Update DNS record cache") }).Tracef("Update DNS record cache")
} }
if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil { cacheTimeout := time.Duration(ttl) * time.Second // TTL.
if cacheTimeout < minFirefoxCacheTimeout {
cacheTimeout = minFirefoxCacheTimeout
}
cacheTimeout += 5 * time.Second // DNS lookup timeout.
if err = c.UpdateDnsCache(q.Name.String(), q.Type.String(), msg.Answers, time.Now().Add(cacheTimeout)); err != nil {
return nil, err return nil, err
} }
// Pack to get newData. // Pack to get newData.
return &msg, nil return &msg, nil
} }
func (c *DnsController) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) { func (c *DnsController) UpdateDnsCache(host string, dnsTyp string, answers []dnsmessage.Resource, deadline time.Time) (err error) {
var fqdn string var fqdn string
if strings.HasSuffix(host, ".") { if strings.HasSuffix(host, ".") {
fqdn = host fqdn = host
@ -213,7 +222,7 @@ func (c *DnsController) UpdateDnsCache(host string, typ dnsmessage.Type, answers
if _, err = netip.ParseAddr(host); err == nil { if _, err = netip.ParseAddr(host); err == nil {
return nil return nil
} }
cacheKey := fqdn + typ.String() cacheKey := fqdn + dnsTyp
c.dnsCacheMu.Lock() c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[cacheKey] cache, ok := c.dnsCache[cacheKey]
if ok { if ok {