feat: support --abort for reload and suspend (#346)

This commit is contained in:
mzz
2023-11-14 16:26:33 +08:00
committed by GitHub
parent 7e57531f91
commit 9f7a49b81d
5 changed files with 39 additions and 1 deletions

View File

@ -5,6 +5,10 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
const (
AbortFile = "/var/run/dae.abort"
)
var ( var (
Version = "unknown" Version = "unknown"
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{

View File

@ -17,6 +17,7 @@ import (
) )
var ( var (
abort bool
reloadCmd = &cobra.Command{ reloadCmd = &cobra.Command{
Use: "reload [pid]", Use: "reload [pid]",
Short: "To reload config file without interrupt connections.", Short: "To reload config file without interrupt connections.",
@ -35,6 +36,11 @@ var (
cmd.Help() cmd.Help()
os.Exit(1) os.Exit(1)
} }
if abort {
if f, err := os.Create(AbortFile); err == nil {
f.Close()
}
}
if err = syscall.Kill(pid, syscall.SIGUSR1); err != nil { if err = syscall.Kill(pid, syscall.SIGUSR1); err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -46,4 +52,5 @@ var (
func init() { func init() {
rootCmd.AddCommand(reloadCmd) rootCmd.AddCommand(reloadCmd)
reloadCmd.PersistentFlags().BoolVarP(&abort, "abort", "a", false, "Abort established connections.")
} }

View File

@ -108,6 +108,8 @@ var (
) )
func Run(log *logrus.Logger, conf *config.Config, externGeoDataDirs []string) (err error) { func Run(log *logrus.Logger, conf *config.Config, externGeoDataDirs []string) (err error) {
// Remove AbortFile at beginning.
_ = os.Remove(AbortFile)
// New ControlPlane. // New ControlPlane.
c, err := newControlPlane(log, nil, nil, conf, externGeoDataDirs) c, err := newControlPlane(log, nil, nil, conf, externGeoDataDirs)
@ -135,6 +137,7 @@ func Run(log *logrus.Logger, conf *config.Config, externGeoDataDirs []string) (e
}() }()
reloading := false reloading := false
isSuspend := false isSuspend := false
abortConnections := false
loop: loop:
for sig := range sigs { for sig := range sigs {
switch sig { switch sig {
@ -174,6 +177,7 @@ loop:
sdnotify.Reloading() sdnotify.Reloading()
// Load new config. // Load new config.
abortConnections = os.Remove(AbortFile) == nil
log.Warnln("[Reload] Load new config") log.Warnln("[Reload] Load new config")
var newConf *config.Config var newConf *config.Config
if isSuspend { if isSuspend {
@ -247,6 +251,9 @@ loop:
reloading = true reloading = true
// Ready to close. // Ready to close.
if abortConnections {
oldC.AbortConnections()
}
oldC.Close() oldC.Close()
case syscall.SIGHUP: case syscall.SIGHUP:
// Ignore. // Ignore.

View File

@ -35,6 +35,11 @@ var (
cmd.Help() cmd.Help()
os.Exit(1) os.Exit(1)
} }
if abort {
if f, err := os.Create(AbortFile); err == nil {
f.Close()
}
}
if err = syscall.Kill(pid, syscall.SIGUSR2); err != nil { if err = syscall.Kill(pid, syscall.SIGUSR2); err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -46,4 +51,5 @@ var (
func init() { func init() {
rootCmd.AddCommand(suspendCmd) rootCmd.AddCommand(suspendCmd)
suspendCmd.PersistentFlags().BoolVarP(&abort, "abort", "a", false, "Abort established connections.")
} }

View File

@ -7,6 +7,7 @@ package control
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -50,7 +51,8 @@ type ControlPlane struct {
listenIp string listenIp string
// TODO: add mutex? // TODO: add mutex?
outbounds []*outbound.DialerGroup outbounds []*outbound.DialerGroup
inConnections sync.Map
dnsController *DnsController dnsController *DnsController
onceNetworkReady sync.Once onceNetworkReady sync.Once
@ -707,6 +709,8 @@ func (c *ControlPlane) Serve(readyChan chan<- bool, listener *Listener) (err err
break break
} }
go func(lconn net.Conn) { go func(lconn net.Conn) {
c.inConnections.Store(lconn, struct{}{})
defer c.inConnections.Delete(lconn)
if err := c.handleConn(lconn); err != nil { if err := c.handleConn(lconn); err != nil {
c.log.Warnln("handleConn:", err) c.log.Warnln("handleConn:", err)
} }
@ -920,6 +924,16 @@ func (c *ControlPlane) chooseBestDnsDialer(
}, nil }, nil
} }
func (c *ControlPlane) AbortConnections() (err error) {
var errs []error
c.inConnections.Range(func(key, value any) bool {
if err = key.(net.Conn).Close(); err != nil {
errs = append(errs, err)
}
return true
})
return errors.Join(errs...)
}
func (c *ControlPlane) Close() (err error) { func (c *ControlPlane) Close() (err error) {
// Invoke defer funcs in reverse order. // Invoke defer funcs in reverse order.
for i := len(c.deferFuncs) - 1; i >= 0; i-- { for i := len(c.deferFuncs) - 1; i >= 0; i-- {