fix: rollback reloading and refresh fixed outbound alive state

This commit is contained in:
mzz2017
2023-03-23 15:34:56 +08:00
parent 7adbb2fbb9
commit b69cb63a11
5 changed files with 101 additions and 56 deletions

View File

@ -186,6 +186,7 @@ loop:
"err": err, "err": err,
}).Fatalln("[Reload] Failed to roll back configuration") }).Fatalln("[Reload] Failed to roll back configuration")
} }
newConf = conf
log.Warnln("[Reload] Last reload failed; rolled back configuration") log.Warnln("[Reload] Last reload failed; rolled back configuration")
} else { } else {
log.Warnln("[Reload] Stopped old control plane") log.Warnln("[Reload] Stopped old control plane")

View File

@ -7,10 +7,10 @@ package outbound
import ( import (
"fmt" "fmt"
"github.com/mzz2017/softwind/netproxy"
"github.com/sirupsen/logrus"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/mzz2017/softwind/netproxy"
"github.com/sirupsen/logrus"
"time" "time"
) )
@ -38,84 +38,96 @@ func NewDialerGroup(option *dialer.GlobalOption, name string, dialers []*dialer.
var aliveDnsUdp4DialerSet *dialer.AliveDialerSet var aliveDnsUdp4DialerSet *dialer.AliveDialerSet
var aliveDnsUdp6DialerSet *dialer.AliveDialerSet var aliveDnsUdp6DialerSet *dialer.AliveDialerSet
var needAliveState bool
switch p.Policy { switch p.Policy {
case consts.DialerSelectionPolicy_Random, case consts.DialerSelectionPolicy_Random,
consts.DialerSelectionPolicy_MinLastLatency, consts.DialerSelectionPolicy_MinLastLatency,
consts.DialerSelectionPolicy_MinAverage10Latencies, consts.DialerSelectionPolicy_MinAverage10Latencies,
consts.DialerSelectionPolicy_MinMovingAverageLatencies: consts.DialerSelectionPolicy_MinMovingAverageLatencies:
// Need to know the alive state or latency. // Need to know the alive state or latency.
networkType := &dialer.NetworkType{ needAliveState = true
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4, case consts.DialerSelectionPolicy_Fixed:
IsDns: false, // No need to know if the dialer is alive.
} needAliveState = false
default:
log.Panicf("Unexpected dialer selection policy: %v", p.Policy)
}
networkType := &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4,
IsDns: false,
}
if needAliveState {
aliveTcp4DialerSet = dialer.NewAliveDialerSet( aliveTcp4DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers, log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) { func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType. // Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType, false) } return func(alive bool) { aliveChangeCallback(alive, networkType, false) }
}(networkType), true) }(networkType), true)
aliveChangeCallback(true, networkType, true) }
aliveChangeCallback(true, networkType, true)
networkType = &dialer.NetworkType{ networkType = &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP, L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6, IpVersion: consts.IpVersionStr_6,
IsDns: false, IsDns: false,
} }
if needAliveState {
aliveTcp6DialerSet = dialer.NewAliveDialerSet( aliveTcp6DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers, log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) { func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType. // Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType, false) } return func(alive bool) { aliveChangeCallback(alive, networkType, false) }
}(networkType), true) }(networkType), true)
aliveChangeCallback(true, networkType, true) }
aliveChangeCallback(true, networkType, true)
networkType = &dialer.NetworkType{ networkType = &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_UDP, L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_4, IpVersion: consts.IpVersionStr_4,
IsDns: true, IsDns: true,
} }
if needAliveState {
aliveDnsUdp4DialerSet = dialer.NewAliveDialerSet( aliveDnsUdp4DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers, log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) { func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType. // Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType, false) } return func(alive bool) { aliveChangeCallback(alive, networkType, false) }
}(networkType), true) }(networkType), true)
aliveChangeCallback(true, networkType, true) }
aliveChangeCallback(true, networkType, true)
networkType = &dialer.NetworkType{ networkType = &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_UDP, L4Proto: consts.L4ProtoStr_UDP,
IpVersion: consts.IpVersionStr_6, IpVersion: consts.IpVersionStr_6,
IsDns: true, IsDns: true,
} }
if needAliveState {
aliveDnsUdp6DialerSet = dialer.NewAliveDialerSet( aliveDnsUdp6DialerSet = dialer.NewAliveDialerSet(
log, name, networkType, option.CheckTolerance, p.Policy, dialers, log, name, networkType, option.CheckTolerance, p.Policy, dialers,
func(networkType *dialer.NetworkType) func(alive bool) { func(networkType *dialer.NetworkType) func(alive bool) {
// Use the trick to copy a pointer of *dialer.NetworkType. // Use the trick to copy a pointer of *dialer.NetworkType.
return func(alive bool) { aliveChangeCallback(alive, networkType, false) } return func(alive bool) { aliveChangeCallback(alive, networkType, false) }
}(networkType), true) }(networkType), true)
aliveChangeCallback(true, networkType, true) }
aliveChangeCallback(true, networkType, true)
if option.CheckDnsTcp { if option.CheckDnsTcp && needAliveState {
aliveDnsTcp4DialerSet = dialer.NewAliveDialerSet(log, name, &dialer.NetworkType{ aliveDnsTcp4DialerSet = dialer.NewAliveDialerSet(log, name, &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP, L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_4, IpVersion: consts.IpVersionStr_4,
IsDns: true, IsDns: true,
}, option.CheckTolerance, p.Policy, dialers, func(alive bool) {}, true) }, option.CheckTolerance, p.Policy, dialers, func(alive bool) {}, true)
aliveDnsTcp6DialerSet = dialer.NewAliveDialerSet(log, name, &dialer.NetworkType{ aliveDnsTcp6DialerSet = dialer.NewAliveDialerSet(log, name, &dialer.NetworkType{
L4Proto: consts.L4ProtoStr_TCP, L4Proto: consts.L4ProtoStr_TCP,
IpVersion: consts.IpVersionStr_6, IpVersion: consts.IpVersionStr_6,
IsDns: true, IsDns: true,
}, option.CheckTolerance, p.Policy, dialers, func(alive bool) {}, true) }, option.CheckTolerance, p.Policy, dialers, func(alive bool) {}, true)
}
case consts.DialerSelectionPolicy_Fixed:
// No need to know if the dialer is alive.
default:
log.Panicf("Unexpected dialer selection policy: %v", p.Policy)
} }
for _, d := range dialers { for _, d := range dialers {

View File

@ -7,8 +7,8 @@ package control
import ( import (
"github.com/cilium/ebpf" "github.com/cilium/ebpf"
"github.com/sirupsen/logrus"
"github.com/daeuniverse/dae/component/outbound/dialer" "github.com/daeuniverse/dae/component/outbound/dialer"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"strconv" "strconv"
) )
@ -25,11 +25,17 @@ func FormatL4Proto(l4proto uint8) string {
func (c *controlPlaneCore) OutboundAliveChangeCallback(outbound uint8) func(alive bool, networkType *dialer.NetworkType, isInit bool) { func (c *controlPlaneCore) OutboundAliveChangeCallback(outbound uint8) func(alive bool, networkType *dialer.NetworkType, isInit bool) {
return func(alive bool, networkType *dialer.NetworkType, isInit bool) { return func(alive bool, networkType *dialer.NetworkType, isInit bool) {
if !isInit { select {
case <-c.closed.Done():
return
default:
}
if !isInit || c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"alive": alive, "alive": alive,
"network": networkType.StringWithoutDns(), "network": networkType.StringWithoutDns(),
"outbound": c.outboundId2Name[outbound], "outboundId": outbound,
"outbound": c.outboundId2Name[outbound],
}).Warnf("Outbound alive state changed, notify the kernel program.") }).Warnf("Outbound alive state changed, notify the kernel program.")
} }
@ -37,10 +43,16 @@ func (c *controlPlaneCore) OutboundAliveChangeCallback(outbound uint8) func(aliv
if alive { if alive {
value = 1 value = 1
} }
_ = c.bpf.OutboundConnectivityMap.Update(bpfOutboundConnectivityQuery{ if err := c.bpf.OutboundConnectivityMap.Update(bpfOutboundConnectivityQuery{
Outbound: outbound, Outbound: outbound,
L4proto: networkType.L4Proto.ToL4Proto(), L4proto: networkType.L4Proto.ToL4Proto(),
Ipversion: networkType.IpVersion.ToIpVersion(), Ipversion: networkType.IpVersion.ToIpVersion(),
}, value, ebpf.UpdateAny) }, value, ebpf.UpdateAny); err != nil {
c.log.WithFields(logrus.Fields{
"alive": alive,
"network": networkType.StringWithoutDns(),
"outbound": c.outboundId2Name[outbound],
}).Warnf("Failed to notify the kernel program: %v", err)
}
} }
} }

View File

@ -6,15 +6,16 @@
package control package control
import ( import (
"context"
"fmt" "fmt"
"github.com/cilium/ebpf" "github.com/cilium/ebpf"
ciliumLink "github.com/cilium/ebpf/link" ciliumLink "github.com/cilium/ebpf/link"
"github.com/mohae/deepcopy"
"github.com/safchain/ethtool"
"github.com/sirupsen/logrus"
"github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
internal "github.com/daeuniverse/dae/pkg/ebpf_internal" internal "github.com/daeuniverse/dae/pkg/ebpf_internal"
"github.com/mohae/deepcopy"
"github.com/safchain/ethtool"
"github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@ -38,6 +39,9 @@ type controlPlaneCore struct {
flip int flip int
isReload bool isReload bool
bpfEjected bool bpfEjected bool
closed context.Context
close context.CancelFunc
} }
func newControlPlaneCore(log *logrus.Logger, func newControlPlaneCore(log *logrus.Logger,
@ -53,6 +57,7 @@ func newControlPlaneCore(log *logrus.Logger,
if !isReload { if !isReload {
deferFuncs = append(deferFuncs, bpf.Close) deferFuncs = append(deferFuncs, bpf.Close)
} }
closed, toClose := context.WithCancel(context.Background())
return &controlPlaneCore{ return &controlPlaneCore{
log: log, log: log,
deferFuncs: deferFuncs, deferFuncs: deferFuncs,
@ -61,6 +66,9 @@ func newControlPlaneCore(log *logrus.Logger,
kernelVersion: kernelVersion, kernelVersion: kernelVersion,
flip: coreFlip, flip: coreFlip,
isReload: isReload, isReload: isReload,
bpfEjected: false,
closed: closed,
close: toClose,
} }
} }
@ -68,6 +76,11 @@ func (c *controlPlaneCore) Flip() {
coreFlip = coreFlip&1 ^ 1 coreFlip = coreFlip&1 ^ 1
} }
func (c *controlPlaneCore) Close() (err error) { func (c *controlPlaneCore) Close() (err error) {
select {
case <-c.closed.Done():
return nil
default:
}
// Invoke defer funcs in reverse order. // Invoke defer funcs in reverse order.
for i := len(c.deferFuncs) - 1; i >= 0; i-- { for i := len(c.deferFuncs) - 1; i >= 0; i-- {
if e := c.deferFuncs[i](); e != nil { if e := c.deferFuncs[i](); e != nil {
@ -79,6 +92,7 @@ func (c *controlPlaneCore) Close() (err error) {
} }
} }
} }
c.close()
return err return err
} }

View File

@ -289,6 +289,12 @@ type dialArgument struct {
} }
func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) { func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) {
if c.log.IsLevelEnabled(logrus.TraceLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
c.log.Tracef("Received UDP(DNS) %v <-> %v: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), req.realDst.String(), strings.ToLower(q.Name.String()), q.Type,
)
}
if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil { if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil {
// Send cache to client directly. // Send cache to client directly.
if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil { if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil {
@ -296,7 +302,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
} }
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 { if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0] q := dnsMessage.Questions[0]
c.log.Tracef("UDP(DNS) %v <-> Cache: %v %v", c.log.Debugf("UDP(DNS) %v <-> Cache: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type, RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type,
) )
} }