fix: connection leaks (#624)

Co-authored-by: dae-prow[bot] <136105375+dae-prow[bot]@users.noreply.github.com>
This commit is contained in:
mzz 2024-09-26 22:40:29 +08:00 committed by GitHub
parent da8890c38a
commit 218ae3f654
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 21 additions and 33 deletions

2
.gitignore vendored
View File

@ -4,6 +4,8 @@
*.tmp *.tmp
bpf_bpfeb*.go bpf_bpfeb*.go
bpf_bpfel*.go bpf_bpfel*.go
bpf_*_bpfeb*.go
bpf_*_bpfel*.go
dae dae
outline.json outline.json
go-mod/ go-mod/

View File

@ -105,8 +105,9 @@ var (
Compress: true, Compress: true,
} }
} }
log := logger.NewLogger(conf.Global.LogLevel, disableTimestamp, logOpts) log := logrus.New()
logrus.SetLevel(log.Level) logger.SetLogger(log, conf.Global.LogLevel, disableTimestamp, logOpts)
logger.SetLogger(logrus.StandardLogger(), conf.Global.LogLevel, disableTimestamp, logOpts)
log.Infof("Include config files: [%v]", strings.Join(includes, ", ")) log.Infof("Include config files: [%v]", strings.Join(includes, ", "))
if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil { if err := Run(log, conf, []string{filepath.Dir(cfgFile)}); err != nil {
@ -238,9 +239,11 @@ loop:
} }
// New logger. // New logger.
oldLogOutput := log.Out oldLogOutput := log.Out
log = logger.NewLogger(newConf.Global.LogLevel, disableTimestamp, nil) log = logrus.New()
logger.SetLogger(log, newConf.Global.LogLevel, disableTimestamp, nil)
logger.SetLogger(logrus.StandardLogger(), newConf.Global.LogLevel, disableTimestamp, nil)
log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK. log.SetOutput(oldLogOutput) // FIXME: THIS IS A HACK.
logrus.SetLevel(log.Level) logrus.SetOutput(oldLogOutput)
// New control plane. // New control plane.
obj := c.EjectBpf() obj := c.EjectBpf()
@ -330,8 +333,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
client := http.Client{ client := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect} conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -372,8 +374,7 @@ func newControlPlane(log *logrus.Logger, bpf interface{}, dnsCache map[string]*c
client := http.Client{ client := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
cd := netproxy.ContextDialerConverter{Dialer: direct.SymmetricDirect} conn, err := direct.SymmetricDirect.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", conf.Global.SoMarkFromDae, conf.Global.Mptcp), addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -207,8 +207,7 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
} }
// Dial and write. // Dial and write.
cd := &netproxy.ContextDialerConverter{Dialer: d} c, err := d.DialContext(ctx, network, dns.String())
c, err := cd.DialContext(ctx, network, dns.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -600,12 +600,11 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
if method == "" { if method == "" {
method = http.MethodGet method = http.MethodGet
} }
cd := &netproxy.ContextDialerConverter{Dialer: d.Dialer}
cli := http.Client{ cli := http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
// Force to dial "ip". // Force to dial "ip".
conn, err := cd.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port())) conn, err := d.Dialer.DialContext(ctx, common.MagicNetwork("tcp", soMark, mptcp), net.JoinHostPort(ip.String(), u.Port()))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -562,16 +562,13 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel() defer cancel()
bestContextDialer := netproxy.ContextDialerConverter{
Dialer: dialArgument.bestDialer,
}
switch dialArgument.l4proto { switch dialArgument.l4proto {
case consts.L4ProtoStr_UDP: case consts.L4ProtoStr_UDP:
// Get udp endpoint. // Get udp endpoint.
// TODO: connection pool. // TODO: connection pool.
conn, err = bestContextDialer.DialContext( conn, err = dialArgument.bestDialer.DialContext(
ctxDial, ctxDial,
common.MagicNetwork("udp", dialArgument.mark, dialArgument.mptcp), common.MagicNetwork("udp", dialArgument.mark, dialArgument.mptcp),
dialArgument.bestTarget.String(), dialArgument.bestTarget.String(),
@ -636,7 +633,7 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
case consts.L4ProtoStr_TCP: case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine. // We can block here because we are in a coroutine.
conn, err = bestContextDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String()) conn, err = dialArgument.bestDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String())
if err != nil { if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err) return fmt.Errorf("failed to dial proxy to tcp: %w", err)
} }

View File

@ -164,10 +164,7 @@ func (c *ControlPlane) RouteDialTcp(p *RouteDialParam) (conn netproxy.Conn, err
} }
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel() defer cancel()
cd := netproxy.ContextDialerConverter{ return d.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget)
Dialer: d,
}
return cd.DialContext(ctx, common.MagicNetwork("tcp", routingResult.Mark, c.mptcp), dialTarget)
} }
type WriteCloser interface { type WriteCloser interface {

View File

@ -134,12 +134,9 @@ begin:
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
cd := netproxy.ContextDialerConverter{
Dialer: dialOption.Dialer,
}
ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout) ctx, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
defer cancel() defer cancel()
udpConn, err := cd.DialContext(ctx, dialOption.Network, dialOption.Target) udpConn, err := dialOption.Dialer.DialContext(ctx, dialOption.Network, dialOption.Target)
if err != nil { if err != nil {
return nil, true, err return nil, true, err
} }

2
go.mod
View File

@ -8,7 +8,7 @@ require (
github.com/bits-and-blooms/bloom/v3 v3.5.0 github.com/bits-and-blooms/bloom/v3 v3.5.0
github.com/cilium/ebpf v0.12.3 github.com/cilium/ebpf v0.12.3
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d
github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5 github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f
github.com/fsnotify/fsnotify v1.7.0 github.com/fsnotify/fsnotify v1.7.0
github.com/json-iterator/go v1.1.12 github.com/json-iterator/go v1.1.12
github.com/mholt/archiver/v3 v3.5.1 github.com/mholt/archiver/v3 v3.5.1

4
go.sum
View File

@ -23,8 +23,8 @@ github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBS
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d h1:hnC39MjR7xt5kZjrKlef7DXKFDkiX8MIcDXYC/6Jf9Q= github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d h1:hnC39MjR7xt5kZjrKlef7DXKFDkiX8MIcDXYC/6Jf9Q=
github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d/go.mod h1:VGWGgv7pCP5WGyHGUyb9+nq/gW0yBm+i/GfCNATOJ1M= github.com/daeuniverse/dae-config-dist/go/dae_config v0.0.0-20230604120805-1c27619b592d/go.mod h1:VGWGgv7pCP5WGyHGUyb9+nq/gW0yBm+i/GfCNATOJ1M=
github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5 h1:L450vqT1TO+Ygzd8buBMna8d4/0asT0q74qitGTWSl4= github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f h1:HB2IMJcU6FqLFqgDHbhhK9F0At6AFfpDRKk/oZz3T2A=
github.com/daeuniverse/outbound v0.0.0-20240911144232-d470a59233a5/go.mod h1:0dkFMC58MVUWMB19jwQuXEg1G16uAIAtdAU7v+yWXYs= github.com/daeuniverse/outbound v0.0.0-20240926143218-3cf58cdd942f/go.mod h1:0dkFMC58MVUWMB19jwQuXEg1G16uAIAtdAU7v+yWXYs=
github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810 h1:YtEYouFaNrg9sV9vf3UabvKShKn6sD0QaCdOxCwaF3g= github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810 h1:YtEYouFaNrg9sV9vf3UabvKShKn6sD0QaCdOxCwaF3g=
github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810/go.mod h1:61o2uZUGLrlv1i+oO2rx9sVX0vbf8cHzdSHt7h6lMnM= github.com/daeuniverse/quic-go v0.0.0-20240413031024-943f218e0810/go.mod h1:61o2uZUGLrlv1i+oO2rx9sVX0vbf8cHzdSHt7h6lMnM=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View File

@ -11,9 +11,7 @@ import (
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
) )
func NewLogger(logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Logger) *logrus.Logger { func SetLogger(log *logrus.Logger, logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Logger) {
log := logrus.New()
level, err := logrus.ParseLevel(logLevel) level, err := logrus.ParseLevel(logLevel)
if err != nil { if err != nil {
level = logrus.InfoLevel level = logrus.InfoLevel
@ -28,6 +26,4 @@ func NewLogger(logLevel string, disableTimestamp bool, logFileOpt *lumberjack.Lo
if logFileOpt != nil { if logFileOpt != nil {
log.SetOutput(logFileOpt) log.SetOutput(logFileOpt)
} }
return log
} }