chore: refine code and add annotations

This commit is contained in:
mzz2017
2023-02-08 15:38:13 +08:00
parent 7c0418d245
commit a3d4a06dab
2 changed files with 77 additions and 53 deletions

View File

@ -11,11 +11,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/cilium/ebpf" "github.com/cilium/ebpf"
"github.com/sirupsen/logrus"
"github.com/v2rayA/dae/common" "github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts" "github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/pkg/ebpf_internal" "github.com/v2rayA/dae/pkg/ebpf_internal"
"net/netip" "net/netip"
"os" "os"
"path/filepath"
"reflect" "reflect"
"strings" "strings"
"sync" "sync"
@ -183,3 +185,61 @@ func (p bpfIfParams) CheckVersionRequirement(version *internal.Version) (err err
} }
return nil return nil
} }
type loadBpfOptions struct {
PinPath string
CollectionOptions *ebpf.CollectionOptions
BindLan bool
BindWan bool
}
func selectivelyLoadBpfObjects(
log *logrus.Logger,
bpf *bpfObjects,
opts *loadBpfOptions,
) (err error) {
// Trick. Replace the beams with rotten timbers to reduce the loading.
var obj interface{} = bpf // Bind to both LAN and WAN.
if opts.BindLan && !opts.BindWan {
// Only bind LAN.
obj = &bpfObjectsLan{}
} else if !opts.BindLan && opts.BindWan {
// Only bind to WAN.
// Trick. Replace the beams with rotten timbers.
obj = &bpfObjectsWan{}
}
retryLoadBpf:
if err = loadBpfObjects(obj, opts.CollectionOptions); err != nil {
if errors.Is(err, ebpf.ErrMapIncompatible) {
// Map property is incompatible. Remove the old map and try again.
prefix := "use pinned map "
_, after, ok := strings.Cut(err.Error(), prefix)
if !ok {
return fmt.Errorf("loading objects: bad format: %w", err)
}
mapName, _, _ := strings.Cut(after, ":")
_ = os.Remove(filepath.Join(opts.PinPath, mapName))
log.Infof("Incompatible new map format with existing map %v detected; removed the old one.", mapName)
goto retryLoadBpf
}
// Get detailed log from ebpf.internal.(*VerifierError)
if log.Level == logrus.FatalLevel {
if v := reflect.Indirect(reflect.ValueOf(errors.Unwrap(errors.Unwrap(err)))); v.Kind() == reflect.Struct {
if _log := v.FieldByName("Log"); _log.IsValid() {
if strSlice, ok := _log.Interface().([]string); ok {
log.Fatalln(strings.Join(strSlice, "\n"))
}
}
}
}
if strings.Contains(err.Error(), "no BTF found for kernel version") {
err = fmt.Errorf("%w: maybe installing the linux-headers package will solve it", err)
}
return err
}
if _, ok := obj.(*bpfObjects); !ok {
// Reverse takeover.
AssignBpfObjects(bpf, obj)
}
return nil
}

View File

@ -7,7 +7,6 @@ package control
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/cilium/ebpf" "github.com/cilium/ebpf"
"github.com/cilium/ebpf/rlimit" "github.com/cilium/ebpf/rlimit"
@ -26,7 +25,6 @@ import (
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -68,7 +66,8 @@ func NewControlPlane(
if e != nil { if e != nil {
return nil, fmt.Errorf("failed to get kernel version: %w", e) return nil, fmt.Errorf("failed to get kernel version: %w", e)
} }
// Must judge version from high to low to reduce the number of user upgrading kernel. /// Check linux kernel requirements.
// Check version from high to low to reduce the number of user upgrading kernel.
if kernelVersion.Less(consts.ChecksumFeatureVersion) { if kernelVersion.Less(consts.ChecksumFeatureVersion) {
return nil, fmt.Errorf("your kernel version %v does not support checksum related features; expect >=%v; upgrade your kernel and try again", return nil, fmt.Errorf("your kernel version %v does not support checksum related features; expect >=%v; upgrade your kernel and try again",
kernelVersion.String(), kernelVersion.String(),
@ -90,7 +89,7 @@ func NewControlPlane(
consts.BasicFeatureVersion.String()) consts.BasicFeatureVersion.String())
} }
// Allow the current process to lock memory for eBPF resources. /// Allow the current process to lock memory for eBPF resources.
if err = rlimit.RemoveMemlock(); err != nil { if err = rlimit.RemoveMemlock(); err != nil {
return nil, fmt.Errorf("rlimit.RemoveMemlock:%v", err) return nil, fmt.Errorf("rlimit.RemoveMemlock:%v", err)
} }
@ -99,7 +98,7 @@ func NewControlPlane(
return nil, err return nil, err
} }
// Load pre-compiled programs and maps into the kernel. /// Load pre-compiled programs and maps into the kernel.
log.Infof("Loading eBPF programs and maps into the kernel") log.Infof("Loading eBPF programs and maps into the kernel")
var bpf bpfObjects var bpf bpfObjects
var ProgramOptions = ebpf.ProgramOptions{ var ProgramOptions = ebpf.ProgramOptions{
@ -109,54 +108,19 @@ func NewControlPlane(
ProgramOptions.LogLevel = ebpf.LogLevelBranch | ebpf.LogLevelStats ProgramOptions.LogLevel = ebpf.LogLevelBranch | ebpf.LogLevelStats
// ProgramOptions.LogLevel = ebpf.LogLevelInstruction | ebpf.LogLevelStats // ProgramOptions.LogLevel = ebpf.LogLevelInstruction | ebpf.LogLevelStats
} }
collectionOpts := &ebpf.CollectionOptions{
// Trick. Replace the beams with rotten timbers to reduce the loading.
var obj interface{} = &bpf // Bind to both LAN and WAN.
if len(lanInterface) > 0 && len(wanInterface) == 0 {
// Only bind LAN.
obj = &bpfObjectsLan{}
} else if len(lanInterface) == 0 && len(wanInterface) > 0 {
// Only bind to WAN.
// Trick. Replace the beams with rotten timbers.
obj = &bpfObjectsWan{}
}
retryLoadBpf:
if err = loadBpfObjects(obj, &ebpf.CollectionOptions{
Maps: ebpf.MapOptions{ Maps: ebpf.MapOptions{
PinPath: pinPath, PinPath: pinPath,
}, },
Programs: ProgramOptions, Programs: ProgramOptions,
}
if err = selectivelyLoadBpfObjects(log, &bpf, &loadBpfOptions{
PinPath: pinPath,
CollectionOptions: collectionOpts,
BindLan: len(lanInterface) > 0,
BindWan: len(wanInterface) > 0,
}); err != nil { }); err != nil {
if errors.Is(err, ebpf.ErrMapIncompatible) { return nil, fmt.Errorf("load eBPF objects: %w", err)
// Map property is incompatible. Remove the old map and try again.
prefix := "use pinned map "
_, after, ok := strings.Cut(err.Error(), prefix)
if !ok {
return nil, fmt.Errorf("loading objects: bad format: %w", err)
}
mapName, _, _ := strings.Cut(after, ":")
_ = os.Remove(filepath.Join(pinPath, mapName))
log.Infof("Incompatible new map format with existing map %v detected; removed the old one.", mapName)
goto retryLoadBpf
}
// Get detailed log from ebpf.internal.(*VerifierError)
if log.Level == logrus.FatalLevel {
if v := reflect.Indirect(reflect.ValueOf(errors.Unwrap(errors.Unwrap(err)))); v.Kind() == reflect.Struct {
if _log := v.FieldByName("Log"); _log.IsValid() {
if strSlice, ok := _log.Interface().([]string); ok {
log.Fatalln(strings.Join(strSlice, "\n"))
}
}
}
}
if strings.Contains(err.Error(), "no BTF found for kernel version") {
err = fmt.Errorf("%w: maybe installing the linux-headers package will solve it", err)
}
return nil, fmt.Errorf("loading objects: %w", err)
}
if _, ok := obj.(*bpfObjects); !ok {
// Reverse takeover.
AssignBpfObjects(&bpf, obj)
} }
// Write params. // Write params.
@ -170,7 +134,7 @@ retryLoadBpf:
if err = bpf.ParamMap.Update(consts.ControlPlaneOidKey, uint32(os.Getpid()), ebpf.UpdateAny); err != nil { if err = bpf.ParamMap.Update(consts.ControlPlaneOidKey, uint32(os.Getpid()), ebpf.UpdateAny); err != nil {
return nil, err return nil, err
} }
// Write ip_proto to hdr_size map for IPv6 extension extraction. // Write ip_proto to hdr_size mapping for IPv6 extension extraction (it is just for eBPF code insns optimization).
if err = bpf.IpprotoHdrsizeMap.Update(uint32(unix.IPPROTO_HOPOPTS), int32(-1), ebpf.UpdateAny); err != nil { if err = bpf.IpprotoHdrsizeMap.Update(uint32(unix.IPPROTO_HOPOPTS), int32(-1), ebpf.UpdateAny); err != nil {
return nil, err return nil, err
} }
@ -226,7 +190,7 @@ retryLoadBpf:
} }
} }
// DialerGroups (outbounds). /// DialerGroups (outbounds).
option := &dialer.GlobalOption{ option := &dialer.GlobalOption{
Log: log, Log: log,
CheckUrl: checkUrl, CheckUrl: checkUrl,
@ -274,6 +238,7 @@ retryLoadBpf:
outbounds = append(outbounds, dialerGroup) outbounds = append(outbounds, dialerGroup)
} }
/// Routing.
// Generate outboundName2Id from outbounds. // Generate outboundName2Id from outbounds.
if len(outbounds) > 0xff { if len(outbounds) > 0xff {
return nil, fmt.Errorf("too many outbounds") return nil, fmt.Errorf("too many outbounds")
@ -283,8 +248,6 @@ retryLoadBpf:
outboundName2Id[o.Name] = uint8(i) outboundName2Id[o.Name] = uint8(i)
} }
builder := NewRoutingMatcherBuilder(outboundName2Id, &bpf) builder := NewRoutingMatcherBuilder(outboundName2Id, &bpf)
// Routing.
var rules []*config_parser.RoutingRule var rules []*config_parser.RoutingRule
if rules, err = routing.ApplyRulesOptimizers(routingA.Rules, if rules, err = routing.ApplyRulesOptimizers(routingA.Rules,
&routing.RefineFunctionParamKeyOptimizer{}, &routing.RefineFunctionParamKeyOptimizer{},
@ -308,7 +271,7 @@ retryLoadBpf:
return nil, fmt.Errorf("RoutingMatcherBuilder.Build: %w", err) return nil, fmt.Errorf("RoutingMatcherBuilder.Build: %w", err)
} }
// DNS upstream. /// DNS upstream.
var dnsAddrPort netip.AddrPort var dnsAddrPort netip.AddrPort
if dnsUpstream != "" { if dnsUpstream != "" {
dnsAddrPort, err = netip.ParseAddrPort(dnsUpstream) dnsAddrPort, err = netip.ParseAddrPort(dnsUpstream)
@ -332,6 +295,7 @@ retryLoadBpf:
} }
} }
/// Listen address.
listenIp := "::1" listenIp := "::1"
if len(wanInterface) > 0 { if len(wanInterface) > 0 {
listenIp = "0.0.0.0" listenIp = "0.0.0.0"