From 218ae3f65432a2d84063a2987140ea4b40f60239 Mon Sep 17 00:00:00 2001 From: mzz <2017@duck.com> Date: Thu, 26 Sep 2024 22:40:29 +0800 Subject: [PATCH] fix: connection leaks (#624) Co-authored-by: dae-prow[bot] <136105375+dae-prow[bot]@users.noreply.github.com> --- .gitignore | 2 ++ cmd/run.go | 17 +++++++++-------- common/netutils/dns.go | 3 +-- component/outbound/dialer/connectivity_check.go | 3 +-- control/dns_control.go | 7 ++----- control/tcp.go | 5 +---- control/udp_endpoint_pool.go | 5 +---- go.mod | 2 +- go.sum | 4 ++-- pkg/logger/logger.go | 6 +----- 10 files changed, 21 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index 102649a..25c56d8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.tmp bpf_bpfeb*.go bpf_bpfel*.go +bpf_*_bpfeb*.go +bpf_*_bpfel*.go dae outline.json go-mod/ diff --git a/cmd/run.go b/cmd/run.go index f7dde16..16f2fc5 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -105,8 +105,9 @@ var ( Compress: true, } } - log := logger.NewLogger(conf.Global.LogLevel, disableTimestamp, logOpts) - logrus.SetLevel(log.Level) + log := logrus.New() + logger.SetLogger(log, conf.Global.LogLevel, disableTimestamp, logOpts) + logger.SetLogger(logrus.StandardLogger(), conf.Global.LogLevel, disableTimestamp, logOpts) log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil { @@ -238,9 +239,11 @@ loop: } // New logger. oldLogOutput := log.Out - log = logger.NewLogger(newConf.Global.LogLevel, disableTimestamp, nil) + log = logrus.New() + logger.SetLogger(log, newConf.Global.LogLevel, disableTimestamp, nil) + logger.SetLogger(logrus.StandardLogger(), newConf.Global.LogLevel, disableTimestamp, nil) log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK. - logrus.SetLevel(log.Level) + logrus.SetOutput(oldLogOutput) // New control plane. obj := c.EjectBpf() @@ -330,8 +333,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c client := http.Client{ Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { - cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect} - conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) + conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) if err != nil { return nil, err } @@ -372,8 +374,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c client := http.Client{ Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { - cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect} - conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) + conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr) if err != nil { return nil, err } diff --git a/common/netutils/dns.go b/common/netutils/dns.go index 7c37447..fdb3eb6 100644 --- a/common/netutils/dns.go +++ b/common/netutils/dns.go @@ -207,8 +207,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st } // Dial and write. - cd := &netproxy.ContextDialerConverter{Dialer: d} - c, err := cd.DialContext(ctx, network, dns.String()) + c, err := d.DialContext(ctx, network, dns.String()) if err != nil { return nil, err } diff --git a/component/outbound/dialer/connectivity_check.go b/component/outbound/dialer/connectivity_check.go index 4d330e3..a725a96 100644 --- a/component/outbound/dialer/connectivity_check.go +++ b/component/outbound/dialer/connectivity_check.go @@ -600,12 +600,11 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr, if method == "" { method = http.MethodGet } - cd := &netproxy.ContextDialerConverter{Dialer: d.Dialer} cli := http.Client{ Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { // Force to dial "ip". - conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port())) + conn, err := d.Dialer.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port())) if err != nil { return nil, err } diff --git a/control/dns_control.go b/control/dns_control.go index b6cbaee..ac653e8 100644 --- a/control/dns_control.go +++ b/control/dns_control.go @@ -562,16 +562,13 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) defer cancel() - bestContextDialer := netproxy.ContextDialerConverter{ - Dialer: dialArgument.bestDialer, - } switch dialArgument.l4proto { case consts.L4ProtoStr_UDP: // Get udp endpoint. // TODO: connection pool. - conn, err = bestContextDialer.DialContext( + conn, err = dialArgument.bestDialer.DialContext( ctxDial, common.MagicNetwork("udp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String(), @@ -636,7 +633,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte case consts.L4ProtoStr_TCP: // We can block here because we are in a coroutine. - conn, err = bestContextDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String()) + conn, err = dialArgument.bestDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String()) if err != nil { return fmt.Errorf("failed to dial proxy to tcp: %w", err) } diff --git a/control/tcp.go b/control/tcp.go index acdb14b..67a0544 100644 --- a/control/tcp.go +++ b/control/tcp.go @@ -164,10 +164,7 @@ func (c *ControlPlane) RouteDialTcp(p *RouteDialParam) (conn netproxy.Conn, err } ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) defer cancel() - cd := netproxy.ContextDialerConverter{ - Dialer: d, - } - return cd.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget) + return d.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget) } type WriteCloser interface { diff --git a/control/udp_endpoint_pool.go b/control/udp_endpoint_pool.go index 1bb29e0..da7ef00 100644 --- a/control/udp_endpoint_pool.go +++ b/control/udp_endpoint_pool.go @@ -134,12 +134,9 @@ begin: if err != nil { return nil, false, err } - cd := netproxy.ContextDialerConverter{ - Dialer: dialOption.Dialer, - } ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) defer cancel() - udpConn, err := cd.DialContext(ctx, dialOption.Network, dialOption.Target) + udpConn, err := dialOption.Dialer.DialContext(ctx, dialOption.Network, dialOption.Target) if err != nil { return nil, true, err } diff --git a/go.mod b/go.mod index 573ed27..1f78498 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/bits-and-blooms/bloom/v3 v3.5.0 github.com/cilium/ebpf v0.12.3 github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d - github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5 + github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f github.com/fsnotify/fsnotify v1.7.0 github.com/json-iterator/go v1.1.12 github.com/mholt/archiver/v3 v3.5.1 diff --git a/go.sum b/go.sum index e9f0383..3e82961 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBS github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d h1:hnC39MjR7xt5kZjrKlef7DXKFDkiX8MIcDXYC/6Jf9Q= github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d/go.mod h1:VGWGgv7pCP5WGyHGUyb9+nq/gW0yBm+i/GfCNATOJ1M= -github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5 h1:L450vqT1TO+Ygzd8buBMna8d4/0asT0q74qitGTWSl4= -github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5/go.mod h1:0dkFMC58MVUWMB19jwQuXEg1G16uAIAtdAU7v+yWXYs= +github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f h1:HB2IMJcU6FqLFqgDHbhhK9F0At6AFfpDRKk/oZz3T2A= +github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f/go.mod h1:0dkFMC58MVUWMB19jwQuXEg1G16uAIAtdAU7v+yWXYs= github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810 h1:YtEYouFaNrg9sV9vf3UabvKShKn6sD0QaCdOxCwaF3g= github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810/go.mod h1:61o2uZUGLrlv1i+oO2rx9sVX0vbf8cHzdSHt7h6lMnM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index a1267bb..487b39d 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -11,9 +11,7 @@ import ( "gopkg.in/natefinch/lumberjack.v2" ) -func NewLogger(logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Logger) *logrus.Logger { - log := logrus.New() - +func SetLogger(log *logrus.Logger, logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Logger) { level, err := logrus.ParseLevel(logLevel) if err != nil { level = logrus.InfoLevel @@ -28,6 +26,4 @@ func NewLogger(logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Lo if logFileOpt != nil { log.SetOutput(logFileOpt) } - - return log }