diff --git a/component/dns/dns.go b/component/dns/dns.go index e2b8358..6689263 100644 --- a/component/dns/dns.go +++ b/component/dns/dns.go @@ -50,7 +50,8 @@ func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error if tag == "" { return nil, fmt.Errorf("%w: '%v' has no tag", BadUpstreamFormatError, upstreamRaw) } - u, err := url.Parse(link) + var u *url.URL + u, err = url.Parse(link) if err != nil { return nil, fmt.Errorf("%w: %v", BadUpstreamFormatError, err) } @@ -73,9 +74,6 @@ func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error } upstreamName2Id[tag] = uint8(len(s.upstream)) s.upstream = append(s.upstream, r) - // Init immediately to avoid DNS leaking in the very beginning because param control_plane_dns_routing will - // be set in callback. - go r.GetUpstream() } // Optimize routings. if dns.Routing.Request.Rules, err = routing.ApplyRulesOptimizers(dns.Routing.Request.Rules, @@ -119,6 +117,12 @@ func New(log *logrus.Logger, dns *config.Dns, opt *NewOption) (s *Dns, err error return s, nil } +func (s *Dns) InitUpstreams() { + for _, upstream := range s.upstream { + upstream.GetUpstream() + } +} + func (s *Dns) RequestSelect(msg *dnsmessage.Message) (upstream *Upstream, err error) { if msg.Response { return nil, fmt.Errorf("DNS request expected but DNS response received") diff --git a/control/control_plane.go b/control/control_plane.go index 782d984..0c717dd 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -51,6 +51,9 @@ type ControlPlane struct { dialMode consts.DialMode routingMatcher *RoutingMatcher + + closed chan struct{} + ready chan struct{} } func NewControlPlane( @@ -302,7 +305,14 @@ func NewControlPlane( outbounds: outbounds, dialMode: dialMode, routingMatcher: routingMatcher, + closed: make(chan struct{}), + ready: make(chan struct{}), } + defer func() { + if err != nil { + close(c.closed) + } + }() /// DNS upstream. dnsUpstream, err := dns.New(log, dnsConfig, &dns.NewOption{ @@ -313,7 +323,7 @@ func NewControlPlane( } /// Dns controller. - c.dnsController, err = NewDnsController(dnsUpstream, &DnsControllerOption{ + if c.dnsController, err = NewDnsController(dnsUpstream, &DnsControllerOption{ Log: log, CacheAccessCallback: func(cache *DnsCache) (err error) { // Write mappings into eBPF map: @@ -331,8 +341,14 @@ func NewControlPlane( }, nil }, BestDialerChooser: c.chooseBestDnsDialer, - }) + }); err != nil { + return nil, err + } + // Init immediately to avoid DNS leaking in the very beginning because param control_plane_dns_routing will + // be set in callback. + dnsUpstream.InitUpstreams() + close(c.ready) return c, nil } @@ -342,6 +358,13 @@ func (c *ControlPlane) EjectBpf() *bpfObjects { } func (c *ControlPlane) dnsUpstreamReadyCallback(raw *url.URL, dnsUpstream *dns.Upstream) (err error) { + // Waiting for ready. + select { + case <-c.closed: + return nil + case <-c.ready: + } + /// Notify dialers to check. c.onceNetworkReady.Do(func() { for _, out := range c.outbounds { @@ -736,5 +759,6 @@ func (c *ControlPlane) Close() (err error) { } } } + close(c.closed) return c.core.Close() }