mirror of
https://github.com/daeuniverse/dae.git
synced 2024-12-22 15:44:42 +07:00
388 lines
12 KiB
Go
388 lines
12 KiB
Go
/*
 * SPDX-License-Identifier: AGPL-3.0-only
 * Copyright (c) 2022-2024, daeuniverse Organization <dae@v2raya.org>
 */
|
|
|
|
package control
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"net/netip"
|
|
"strconv"
|
|
|
|
"github.com/daeuniverse/dae/pkg/trie"
|
|
|
|
"github.com/cilium/ebpf"
|
|
"github.com/daeuniverse/dae/common"
|
|
"github.com/daeuniverse/dae/common/consts"
|
|
"github.com/daeuniverse/dae/component/routing"
|
|
"github.com/daeuniverse/dae/component/routing/domain_matcher"
|
|
"github.com/daeuniverse/dae/config"
|
|
"github.com/daeuniverse/dae/pkg/config_parser"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type RoutingMatcherBuilder struct {
|
|
log *logrus.Logger
|
|
outboundName2Id map[string]uint8
|
|
bpf *bpfObjects
|
|
rules []bpfMatchSet
|
|
simulatedLpmTries [][]netip.Prefix
|
|
simulatedDomainSet []routing.DomainSet
|
|
fallback *routing.Outbound
|
|
}
|
|
|
|
func NewRoutingMatcherBuilder(log *logrus.Logger, rules []*config_parser.RoutingRule, outboundName2Id map[string]uint8, bpf *bpfObjects, fallback config.FunctionOrString) (b *RoutingMatcherBuilder, err error) {
|
|
b = &RoutingMatcherBuilder{log: log, outboundName2Id: outboundName2Id, bpf: bpf}
|
|
rulesBuilder := routing.NewRulesBuilder(log)
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_Domain, routing.PlainParserFactory(b.addDomain))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_Ip, routing.IpParserFactory(b.addIp))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_SourceIp, routing.IpParserFactory(b.addSourceIp))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_Port, routing.PortRangeParserFactory(b.addPort))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_SourcePort, routing.PortRangeParserFactory(b.addSourcePort))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_L4Proto, routing.L4ProtoParserFactory(b.addL4Proto))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_Mac, routing.MacParserFactory(b.addSourceMac))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_ProcessName, routing.ProcessNameParserFactory(b.addProcessName))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_Dscp, routing.UintParserFactory(b.addDscp))
|
|
rulesBuilder.RegisterFunctionParser(consts.Function_IpVersion, routing.IpVersionParserFactory(b.addIpVersion))
|
|
if err = rulesBuilder.Apply(rules); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err = b.addFallback(fallback); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return b, nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) outboundToId(outbound string) (uint8, error) {
|
|
var outboundId uint8
|
|
switch outbound {
|
|
case consts.OutboundLogicalOr.String():
|
|
outboundId = uint8(consts.OutboundLogicalOr)
|
|
case consts.OutboundLogicalAnd.String():
|
|
outboundId = uint8(consts.OutboundLogicalAnd)
|
|
case consts.OutboundMustRules.String():
|
|
outboundId = uint8(consts.OutboundMustRules)
|
|
default:
|
|
var ok bool
|
|
outboundId, ok = b.outboundName2Id[outbound]
|
|
if !ok {
|
|
return 0, fmt.Errorf("outbound (group) %v not found; please define it in section \"group\"", strconv.Quote(outbound))
|
|
}
|
|
}
|
|
return outboundId, nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addDomain(f *config_parser.Function, key string, values []string, outbound *routing.Outbound) (err error) {
|
|
switch consts.RoutingDomainKey(key) {
|
|
case consts.RoutingDomainKey_Regex,
|
|
consts.RoutingDomainKey_Full,
|
|
consts.RoutingDomainKey_Keyword,
|
|
consts.RoutingDomainKey_Suffix:
|
|
default:
|
|
return fmt.Errorf("addDomain: unsupported key: %v", key)
|
|
}
|
|
b.simulatedDomainSet = append(b.simulatedDomainSet, routing.DomainSet{
|
|
Key: consts.RoutingDomainKey(key),
|
|
RuleIndex: len(b.rules),
|
|
Domains: values,
|
|
})
|
|
outboundId, err := b.outboundToId(outbound.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.rules = append(b.rules, bpfMatchSet{
|
|
Type: uint8(consts.MatchType_DomainSet),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addSourceMac(f *config_parser.Function, macAddrs [][6]byte, outbound *routing.Outbound) (err error) {
|
|
var addr16 [16]byte
|
|
values := make([]netip.Prefix, 0, len(macAddrs))
|
|
for _, mac := range macAddrs {
|
|
copy(addr16[10:], mac[:])
|
|
prefix := netip.PrefixFrom(netip.AddrFrom16(addr16), 128)
|
|
values = append(values, prefix)
|
|
}
|
|
lpmTrieIndex := len(b.simulatedLpmTries)
|
|
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
|
|
outboundId, err := b.outboundToId(outbound.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
set := bpfMatchSet{
|
|
Value: [16]byte{},
|
|
Type: uint8(consts.MatchType_Mac),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
}
|
|
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
|
|
b.rules = append(b.rules, set)
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) (err error) {
|
|
lpmTrieIndex := len(b.simulatedLpmTries)
|
|
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
|
|
outboundId, err := b.outboundToId(outbound.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
set := bpfMatchSet{
|
|
Value: [16]byte{},
|
|
Type: uint8(consts.MatchType_IpSet),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
}
|
|
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
|
|
b.rules = append(b.rules, set)
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addPort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) (err error) {
|
|
for i, value := range values {
|
|
outboundName := consts.OutboundLogicalOr.String()
|
|
if i == len(values)-1 {
|
|
outboundName = outbound.Name
|
|
}
|
|
outboundId, err := b.outboundToId(outboundName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.rules = append(b.rules, bpfMatchSet{
|
|
Type: uint8(consts.MatchType_Port),
|
|
Value: _bpfPortRange{
|
|
PortStart: value[0],
|
|
PortEnd: value[1],
|
|
}.Encode(),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addSourceIp(f *config_parser.Function, values []netip.Prefix, outbound *routing.Outbound) (err error) {
|
|
lpmTrieIndex := len(b.simulatedLpmTries)
|
|
b.simulatedLpmTries = append(b.simulatedLpmTries, values)
|
|
outboundId, err := b.outboundToId(outbound.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
set := bpfMatchSet{
|
|
Value: [16]byte{},
|
|
Type: uint8(consts.MatchType_SourceIpSet),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
}
|
|
binary.LittleEndian.PutUint32(set.Value[:], uint32(lpmTrieIndex))
|
|
b.rules = append(b.rules, set)
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addSourcePort(f *config_parser.Function, values [][2]uint16, outbound *routing.Outbound) (err error) {
|
|
for i, value := range values {
|
|
outboundName := consts.OutboundLogicalOr.String()
|
|
if i == len(values)-1 {
|
|
outboundName = outbound.Name
|
|
}
|
|
outboundId, err := b.outboundToId(outboundName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.rules = append(b.rules, bpfMatchSet{
|
|
Type: uint8(consts.MatchType_SourcePort),
|
|
Value: _bpfPortRange{
|
|
PortStart: value[0],
|
|
PortEnd: value[1],
|
|
}.Encode(),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addL4Proto(f *config_parser.Function, values consts.L4ProtoType, outbound *routing.Outbound) (err error) {
|
|
outboundId, err := b.outboundToId(outbound.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.rules = append(b.rules, bpfMatchSet{
|
|
Value: [16]byte{byte(values)},
|
|
Type: uint8(consts.MatchType_L4Proto),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addIpVersion(f *config_parser.Function, values consts.IpVersionType, outbound *routing.Outbound) (err error) {
|
|
outboundId, err := b.outboundToId(outbound.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.rules = append(b.rules, bpfMatchSet{
|
|
Value: [16]byte{byte(values)},
|
|
Type: uint8(consts.MatchType_IpVersion),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addProcessName(f *config_parser.Function, values [][consts.TaskCommLen]byte, outbound *routing.Outbound) (err error) {
|
|
for i, value := range values {
|
|
outboundName := consts.OutboundLogicalOr.String()
|
|
if i == len(values)-1 {
|
|
outboundName = outbound.Name
|
|
}
|
|
outboundId, err := b.outboundToId(outboundName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
matchSet := bpfMatchSet{
|
|
Type: uint8(consts.MatchType_ProcessName),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
}
|
|
copy(matchSet.Value[:], value[:])
|
|
b.rules = append(b.rules, matchSet)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addDscp(f *config_parser.Function, values []uint8, outbound *routing.Outbound) (err error) {
|
|
for i, value := range values {
|
|
outboundName := consts.OutboundLogicalOr.String()
|
|
if i == len(values)-1 {
|
|
outboundName = outbound.Name
|
|
}
|
|
outboundId, err := b.outboundToId(outboundName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
matchSet := bpfMatchSet{
|
|
Type: uint8(consts.MatchType_Dscp),
|
|
Not: f.Not,
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
}
|
|
matchSet.Value[0] = value
|
|
b.rules = append(b.rules, matchSet)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) addFallback(fallbackOutbound config.FunctionOrString) (err error) {
|
|
outbound, err := routing.ParseOutbound(config.FunctionOrStringToFunction(fallbackOutbound))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
outboundId, err := b.outboundToId(outbound.Name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
b.rules = append(b.rules, bpfMatchSet{
|
|
Type: uint8(consts.MatchType_Fallback),
|
|
Outbound: outboundId,
|
|
Mark: outbound.Mark,
|
|
Must: outbound.Must,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) BuildKernspace(log *logrus.Logger) (err error) {
|
|
// Update lpm_array_map.
|
|
for i, cidrs := range b.simulatedLpmTries {
|
|
var keys []_bpfLpmKey
|
|
var values []uint32
|
|
for _, cidr := range cidrs {
|
|
keys = append(keys, cidrToBpfLpmKey(cidr))
|
|
values = append(values, 1)
|
|
}
|
|
m, err := b.bpf.newLpmMap(keys, values)
|
|
if err != nil {
|
|
return fmt.Errorf("newLpmMap: %w", err)
|
|
}
|
|
// We cannot invoke BpfMapBatchUpdate when value is ebpf.Map.
|
|
if err = b.bpf.LpmArrayMap.Update(uint32(i), m, ebpf.UpdateAny); err != nil {
|
|
m.Close()
|
|
return fmt.Errorf("Update: %w", err)
|
|
}
|
|
m.Close()
|
|
}
|
|
// Write routings.
|
|
// Fallback rule MUST be the last.
|
|
if b.rules[len(b.rules)-1].Type != uint8(consts.MatchType_Fallback) {
|
|
return fmt.Errorf("fallback rule MUST be the last")
|
|
}
|
|
routingsLen := uint32(len(b.rules))
|
|
routingsKeys := common.ARangeU32(routingsLen)
|
|
if _, err = BpfMapBatchUpdate(b.bpf.RoutingMap, routingsKeys, b.rules, &ebpf.BatchOptions{
|
|
ElemFlags: uint64(ebpf.UpdateAny),
|
|
}); err != nil {
|
|
return fmt.Errorf("BpfMapBatchUpdate: %w", err)
|
|
}
|
|
log.Infof("Routing match set len: %v/%v", len(b.rules), consts.MaxMatchSetLen)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (b *RoutingMatcherBuilder) BuildUserspace() (matcher *RoutingMatcher, err error) {
|
|
// Build domainMatcher
|
|
domainMatcher := domain_matcher.NewAhocorasickSlimtrie(b.log, consts.MaxMatchSetLen)
|
|
for _, domains := range b.simulatedDomainSet {
|
|
domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
|
|
}
|
|
// Build Ip matcher.
|
|
var lpmMatcher []*trie.Trie
|
|
for _, prefixes := range b.simulatedLpmTries {
|
|
t, err := trie.NewTrieFromPrefixes(prefixes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
lpmMatcher = append(lpmMatcher, t)
|
|
}
|
|
if err = domainMatcher.Build(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Write routings.
|
|
// Fallback rule MUST be the last.
|
|
if b.rules[len(b.rules)-1].Type != uint8(consts.MatchType_Fallback) {
|
|
return nil, fmt.Errorf("fallback rule MUST be the last")
|
|
}
|
|
|
|
return &RoutingMatcher{
|
|
lpmMatcher: lpmMatcher,
|
|
domainMatcher: domainMatcher,
|
|
matches: b.rules,
|
|
}, nil
|
|
}
|