From 9f7a49b81d5f3b04889f90b02d7162effbdc0047 Mon Sep 17 00:00:00 2001 From: mzz <2017@duck.com> Date: Tue, 14 Nov 2023 16:26:33 +0800 Subject: [PATCH] feat: support --abort for reload and suspend (#346) --- cmd/cmd.go | 4 ++++ cmd/reload.go | 7 +++++++ cmd/run.go | 7 +++++++ cmd/suspend.go | 6 ++++++ control/control_plane.go | 16 +++++++++++++++- 5 files changed, 39 insertions(+), 1 deletion(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index 8f7f495..a7201d9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,6 +5,10 @@ import ( "github.com/spf13/cobra" ) +const ( + AbortFile = "/var/run/dae.abort" +) + var ( Version = "unknown" rootCmd = &cobra.Command{ diff --git a/cmd/reload.go b/cmd/reload.go index 5560353..7e2f7fd 100644 --- a/cmd/reload.go +++ b/cmd/reload.go @@ -17,6 +17,7 @@ import ( ) var ( + abort bool reloadCmd = &cobra.Command{ Use: "reload [pid]", Short: "To reload config file without interrupt connections.", @@ -35,6 +36,11 @@ var ( cmd.Help() os.Exit(1) } + if abort { + if f, err := os.Create(AbortFile); err == nil { + f.Close() + } + } if err = syscall.Kill(pid, syscall.SIGUSR1); err != nil { fmt.Println(err) os.Exit(1) @@ -46,4 +52,5 @@ var ( func init() { rootCmd.AddCommand(reloadCmd) + reloadCmd.PersistentFlags().BoolVarP(&abort, "abort", "a", false, "Abort established connections.") } diff --git a/cmd/run.go b/cmd/run.go index 493b115..0dbcdeb 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -108,6 +108,8 @@ var ( ) func Run(log *logrus.Logger, conf *config.Config, externGeoDataDirs []string) (err error) { + // Remove AbortFile at beginning. + _ = os.Remove(AbortFile) // New ControlPlane. c, err := newControlPlane(log, nil, nil, conf, externGeoDataDirs) @@ -135,6 +137,7 @@ func Run(log *logrus.Logger, conf *config.Config, externGeoDataDirs []string) (e }() reloading := false isSuspend := false + abortConnections := false loop: for sig := range sigs { switch sig { @@ -174,6 +177,7 @@ loop: sdnotify.Reloading() // Load new config. + abortConnections = os.Remove(AbortFile) == nil log.Warnln("[Reload] Load new config") var newConf *config.Config if isSuspend { @@ -247,6 +251,9 @@ loop: reloading = true // Ready to close. + if abortConnections { + oldC.AbortConnections() + } oldC.Close() case syscall.SIGHUP: // Ignore. diff --git a/cmd/suspend.go b/cmd/suspend.go index 9e8ef1f..b5ff423 100644 --- a/cmd/suspend.go +++ b/cmd/suspend.go @@ -35,6 +35,11 @@ var ( cmd.Help() os.Exit(1) } + if abort { + if f, err := os.Create(AbortFile); err == nil { + f.Close() + } + } if err = syscall.Kill(pid, syscall.SIGUSR2); err != nil { fmt.Println(err) os.Exit(1) @@ -46,4 +51,5 @@ var ( func init() { rootCmd.AddCommand(suspendCmd) + suspendCmd.PersistentFlags().BoolVarP(&abort, "abort", "a", false, "Abort established connections.") } diff --git a/control/control_plane.go b/control/control_plane.go index 00f2384..e658fd8 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -7,6 +7,7 @@ package control import ( "context" + "errors" "fmt" "net" "net/netip" @@ -50,7 +51,8 @@ type ControlPlane struct { listenIp string // TODO: add mutex? - outbounds []*outbound.DialerGroup + outbounds []*outbound.DialerGroup + inConnections sync.Map dnsController *DnsController onceNetworkReady sync.Once @@ -707,6 +709,8 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err break } go func(lconn net.Conn) { + c.inConnections.Store(lconn, struct{}{}) + defer c.inConnections.Delete(lconn) if err := c.handleConn(lconn); err != nil { c.log.Warnln("handleConn:", err) } @@ -920,6 +924,16 @@ func (c *ControlPlane) chooseBestDnsDialer( }, nil } +func (c *ControlPlane) AbortConnections() (err error) { + var errs []error + c.inConnections.Range(func(key, value any) bool { + if err = key.(net.Conn).Close(); err != nil { + errs = append(errs, err) + } + return true + }) + return errors.Join(errs...) +} func (c *ControlPlane) Close() (err error) { // Invoke defer funcs in reverse order. for i := len(c.deferFuncs) - 1; i >= 0; i-- {