feat: support reload

This commit is contained in:
mzz2017
2023-02-27 13:29:42 +08:00
parent 8b59492fe5
commit 01162f3d7e
10 changed files with 270 additions and 107 deletions

View File

@ -38,7 +38,7 @@ import (
type ControlPlane struct {
log *logrus.Logger
core *ControlPlaneCore
core *controlPlaneCore
deferFuncs []func() error
listenIp string
@ -55,6 +55,7 @@ type ControlPlane struct {
func NewControlPlane(
log *logrus.Logger,
_bpf interface{},
tagToNodeList map[string][]string,
groups []config.Group,
routingA *config.Routing,
@ -88,6 +89,8 @@ func NewControlPlane(
consts.BasicFeatureVersion.String())
}
var deferFuncs []func() error
/// Allow the current process to lock memory for eBPF resources.
if err = rlimit.RemoveMemlock(); err != nil {
return nil, fmt.Errorf("rlimit.RemoveMemlock:%v", err)
@ -102,7 +105,7 @@ func NewControlPlane(
/// Load pre-compiled programs and maps into the kernel.
log.Infof("Loading eBPF programs and maps into the kernel")
var bpf bpfObjects
//var bpf bpfObjects
var ProgramOptions = ebpf.ProgramOptions{
KernelTypes: nil,
}
@ -116,24 +119,37 @@ func NewControlPlane(
},
Programs: ProgramOptions,
}
if err = fullLoadBpfObjects(log, &bpf, &loadBpfOptions{
PinPath: pinPath,
CollectionOptions: collectionOpts,
BindLan: len(global.LanInterface) > 0,
BindWan: len(global.WanInterface) > 0,
}); err != nil {
if log.Level == logrus.PanicLevel {
log.Panicln(err)
var bpf *bpfObjects
if _bpf != nil {
if _bpf, ok := _bpf.(*bpfObjects); ok {
bpf = _bpf
} else {
return nil, fmt.Errorf("unexpected bpf type: %T", _bpf)
}
} else {
bpf = new(bpfObjects)
if err = fullLoadBpfObjects(log, bpf, &loadBpfOptions{
PinPath: pinPath,
CollectionOptions: collectionOpts,
BindLan: len(global.LanInterface) > 0,
BindWan: len(global.WanInterface) > 0,
}); err != nil {
if log.Level == logrus.PanicLevel {
log.Panicln(err)
}
return nil, fmt.Errorf("load eBPF objects: %w", err)
}
return nil, fmt.Errorf("load eBPF objects: %w", err)
}
core := &ControlPlaneCore{
log: log,
deferFuncs: []func() error{bpf.Close},
bpf: &bpf,
kernelVersion: &kernelVersion,
}
// outboundId2Name can be modified later.
outboundId2Name := make(map[uint8]string)
core := newControlPlaneCore(
log,
bpf,
outboundId2Name,
&kernelVersion,
_bpf != nil,
)
defer func() {
if err != nil {
_ = core.Close()
@ -141,16 +157,7 @@ func NewControlPlane(
}()
// Write params.
var lanNatDirect uint32
if global.LanNatDirect {
lanNatDirect = 1
} else {
lanNatDirect = 0
}
if err = bpf.ParamMap.Update(consts.ControlPlaneNatDirectKey, lanNatDirect, ebpf.UpdateAny); err != nil {
return nil, err
}
if err = bpf.ParamMap.Update(consts.ControlPlanePidKey, uint32(os.Getpid()), ebpf.UpdateAny); err != nil {
if err = core.bpf.ParamMap.Update(consts.ControlPlanePidKey, uint32(os.Getpid()), ebpf.UpdateAny); err != nil {
return nil, err
}
@ -212,6 +219,7 @@ func NewControlPlane(
// Filter out groups.
dialerSet := outbound.NewDialerSetFromLinks(option, tagToNodeList)
deferFuncs = append(deferFuncs, dialerSet.Close)
for _, group := range groups {
// Parse policy.
policy, err := outbound.NewDialerSelectionPolicyFromGroupParam(&group)
@ -244,7 +252,6 @@ func NewControlPlane(
return nil, fmt.Errorf("too many outbounds")
}
outboundName2Id := make(map[string]uint8)
outboundId2Name := make(map[uint8]string)
for i, o := range outbounds {
if _, exist := outboundName2Id[o.Name]; exist {
return nil, fmt.Errorf("duplicated outbound name: %v", o.Name)
@ -252,7 +259,6 @@ func NewControlPlane(
outboundName2Id[o.Name] = uint8(i)
outboundId2Name[uint8(i)] = o.Name
}
core.outboundId2Name = outboundId2Name
// Apply rules optimizers.
var rules []*config_parser.RoutingRule
if rules, err = routing.ApplyRulesOptimizers(routingA.Rules,
@ -272,7 +278,7 @@ func NewControlPlane(
log.Debugf("RoutingA:\n%vfallback: %v\n", debugBuilder.String(), routingA.Fallback)
}
// Parse rules and build.
builder, err := NewRoutingMatcherBuilder(log, rules, outboundName2Id, &bpf, routingA.Fallback)
builder, err := NewRoutingMatcherBuilder(log, rules, outboundName2Id, core.bpf, routingA.Fallback)
if err != nil {
return nil, fmt.Errorf("NewRoutingMatcherBuilder: %w", err)
}
@ -293,7 +299,7 @@ func NewControlPlane(
c = &ControlPlane{
log: log,
core: core,
deferFuncs: nil,
deferFuncs: deferFuncs,
listenIp: "0.0.0.0",
outbounds: outbounds,
dialMode: dialMode,
@ -332,6 +338,11 @@ func NewControlPlane(
return c, nil
}
// EjectBpf will resect bpf from destroying life-cycle of control plane.
func (c *ControlPlane) EjectBpf() *bpfObjects {
return c.core.EjectBpf()
}
func (c *ControlPlane) dnsUpstreamReadyCallback(raw *url.URL, dnsUpstream *dns.Upstream) (err error) {
/// Notify dialers to check.
c.onceNetworkReady.Do(func() {
@ -424,7 +435,17 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
case consts.DialMode_Ip:
dialTarget = dst.String()
case consts.DialMode_Domain:
dialTarget = net.JoinHostPort(domain, strconv.Itoa(int(dst.Port())))
if _, err := netip.ParseAddr(domain); err == nil {
// domain is IPv4 or IPv6 (has colon)
dialTarget = net.JoinHostPort(domain, strconv.Itoa(int(dst.Port())))
} else if _, _, err := net.SplitHostPort(domain); err == nil {
// domain is already domain:port
dialTarget = domain
} else {
dialTarget = net.JoinHostPort(domain, strconv.Itoa(int(dst.Port())))
}
c.log.WithFields(logrus.Fields{
"from": dst.String(),
"to": dialTarget,
@ -433,29 +454,33 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip
return dialTarget, dialMode
}
func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
// Listen.
var listenConfig = net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return dialer.TproxyControl(c)
},
}
listenAddr := net.JoinHostPort(c.listenIp, strconv.Itoa(int(port)))
tcpListener, err := listenConfig.Listen(context.TODO(), "tcp", listenAddr)
if err != nil {
return fmt.Errorf("listenTCP: %w", err)
}
defer tcpListener.Close()
packetConn, err := listenConfig.ListenPacket(context.TODO(), "udp", listenAddr)
if err != nil {
return fmt.Errorf("listenUDP: %w", err)
}
defer packetConn.Close()
udpConn := packetConn.(*net.UDPConn)
type Listener struct {
tcpListener net.Listener
packetConn net.PacketConn
port uint16
}
func (l *Listener) Close() error {
var (
err error
err2 error
)
if err, err2 = l.tcpListener.Close(), l.packetConn.Close(); err2 != nil {
if err == nil {
err = err2
} else {
err = fmt.Errorf("%w: %v", err, err2)
}
}
return err
}
func (c *ControlPlane) Serve(listener *Listener) (err error) {
udpConn := listener.packetConn.(*net.UDPConn)
/// Serve.
// TCP socket.
tcpFile, err := tcpListener.(*net.TCPListener).File()
tcpFile, err := listener.tcpListener.(*net.TCPListener).File()
if err != nil {
return fmt.Errorf("failed to retrieve copy of the underlying TCP connection file")
}
@ -477,7 +502,7 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
return err
}
// Port.
if err := c.core.bpf.ParamMap.Update(consts.BigEndianTproxyPortKey, uint32(common.Htons(port)), ebpf.UpdateAny); err != nil {
if err := c.core.bpf.ParamMap.Update(consts.BigEndianTproxyPortKey, uint32(common.Htons(listener.port)), ebpf.UpdateAny); err != nil {
return err
}
@ -489,23 +514,33 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
go func() {
defer cancel()
for {
lconn, err := tcpListener.Accept()
select {
case <-ctx.Done():
return
default:
}
lconn, err := listener.tcpListener.Accept()
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
c.log.Errorf("Error when accept: %v", err)
}
break
}
go func() {
go func(lconn net.Conn) {
if err := c.handleConn(lconn); err != nil {
c.log.Warnln("handleConn:", err)
}
}()
}(lconn)
}
}()
go func() {
defer cancel()
for {
select {
case <-ctx.Done():
return
default:
}
var buf [65535]byte
var oob [120]byte // Size for original dest
n, oobn, _, src, err := udpConn.ReadMsgUDPAddrPort(buf[:], oob[:])
@ -551,6 +586,42 @@ func (c *ControlPlane) ListenAndServe(port uint16) (err error) {
return nil
}
func (c *ControlPlane) ListenAndServe(port uint16) (listener *Listener, err error) {
// Listen.
var listenConfig = net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
return dialer.TproxyControl(c)
},
}
listenAddr := net.JoinHostPort(c.listenIp, strconv.Itoa(int(port)))
tcpListener, err := listenConfig.Listen(context.TODO(), "tcp", listenAddr)
if err != nil {
return nil, fmt.Errorf("listenTCP: %w", err)
}
packetConn, err := listenConfig.ListenPacket(context.TODO(), "udp", listenAddr)
if err != nil {
_ = tcpListener.Close()
return nil, fmt.Errorf("listenUDP: %w", err)
}
listener = &Listener{
tcpListener: tcpListener,
packetConn: packetConn,
port: port,
}
defer func() {
if err != nil {
_ = listener.Close()
}
}()
// Serve
if err = c.Serve(listener); err != nil {
return nil, fmt.Errorf("failed to serve: %w", err)
}
return listener, nil
}
func (c *ControlPlane) chooseBestDnsDialer(
req *udpRequest,
dnsUpstream *dns.Upstream,