// dae/component/control/routing_matcher_builder.go
// 2023-01-23 20:01:24 +08:00
//
// 134 lines
// 3.4 KiB
// Go

/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) since 2022, mzz2017 (mzz@tuta.io). All rights reserved.
*/
package control
import (
"fmt"
"github.com/v2rayA/dae/common"
"github.com/v2rayA/dae/common/consts"
"github.com/v2rayA/dae/component/routing"
"github.com/cilium/ebpf"
"net/netip"
"strconv"
)
// DomainSet is one domain rule recorded by RoutingMatcherBuilder.AddDomain.
// The domain matching itself is not performed in this file; the set is only
// collected here (see SimulatedDomainSet).
type DomainSet struct {
// Key is the domain match kind. AddDomain only accepts
// consts.RoutingDomain_Regex, consts.RoutingDomain_Full,
// consts.RoutingDomain_Keyword and consts.RoutingDomain_Suffix.
Key string
// RuleIndex is the position, in the builder's rule list, of the
// RoutingType_DomainSet rule that AddDomain appended for this set.
RuleIndex int
// Domains holds the values passed to AddDomain (presumably interpreted
// according to Key by the consumer — confirm at the use site).
Domains []string
}
// RoutingMatcherBuilder accumulates routing rules through its Add* methods and
// writes them into the BPF maps held by bpf when Build is called.
//
// Errors are sticky: once err is set, AddDomain and AddIp become no-ops and
// Build returns err without touching any map.
type RoutingMatcherBuilder struct {
// NOTE(review): presumably supplies default implementations of the other
// matcher-builder methods; NewRoutingMatcherBuilder leaves it nil.
*routing.DefaultMatcherBuilder
// outboundName2Id maps an outbound name to the id stored in bpfRouting.Outbound.
outboundName2Id map[string]uint8
// bpf provides the maps (LpmArrayMap, RoutingMap, ParamMap) that Build writes.
bpf *bpfObjects
// rules are the routing entries in insertion order; Build writes them to RoutingMap.
rules []bpfRouting
// SimulatedLpmTries holds one CIDR list per IP rule; the slice index is the
// value stored in that rule's bpfRouting.Value and the LpmArrayMap key.
SimulatedLpmTries [][]netip.Prefix
// SimulatedDomainSet holds one entry per domain rule added via AddDomain.
SimulatedDomainSet []DomainSet
// Final is the outbound name most recently passed to AddFinal.
Final string
// err is the error recorded by an Add* / OutboundToId call; returned by Build.
err error
}
// NewRoutingMatcherBuilder returns an empty builder that resolves outbound
// names through outboundName2Id and writes its result into the maps of bpf
// when Build is called.
//
// NOTE(review): the embedded *routing.DefaultMatcherBuilder is left nil here;
// confirm that none of its promoted methods dereference the receiver.
func NewRoutingMatcherBuilder(outboundName2Id map[string]uint8, bpf *bpfObjects) *RoutingMatcherBuilder {
	builder := new(RoutingMatcherBuilder)
	builder.outboundName2Id = outboundName2Id
	builder.bpf = bpf
	return builder
}
// OutboundToId translates an outbound name into the id stored in a
// bpfRouting rule.
//
// routing.FakeOutbound_AND maps to consts.OutboundLogicalAnd. Any other name is
// looked up in outboundName2Id. An unknown name yields 0 and records an error
// on the builder, which Build will then return. If an error was already
// recorded, it is kept: the first failure is the root cause worth reporting.
func (b *RoutingMatcherBuilder) OutboundToId(outbound string) uint8 {
	if outbound == routing.FakeOutbound_AND {
		return uint8(consts.OutboundLogicalAnd)
	}
	outboundId, ok := b.outboundName2Id[outbound]
	if !ok && b.err == nil {
		b.err = fmt.Errorf("%v not in outboundName2Id", strconv.Quote(outbound))
	}
	return outboundId
}
// AddDomain records a domain rule of kind key matching values and routed to
// outbound.
//
// It appends a DomainSet to SimulatedDomainSet, pointing at the
// RoutingType_DomainSet rule appended to the rule list in the same call.
// An unsupported key records an error. The call does nothing if an error was
// already recorded.
func (b *RoutingMatcherBuilder) AddDomain(key string, values []string, outbound string) {
	if b.err != nil {
		return
	}
	if key != consts.RoutingDomain_Regex &&
		key != consts.RoutingDomain_Full &&
		key != consts.RoutingDomain_Keyword &&
		key != consts.RoutingDomain_Suffix {
		b.err = fmt.Errorf("AddDomain: unsupported key: %v", key)
		return
	}

	// The rule about to be appended will sit at the current end of b.rules.
	set := DomainSet{
		Key:       key,
		RuleIndex: len(b.rules),
		Domains:   values,
	}
	b.SimulatedDomainSet = append(b.SimulatedDomainSet, set)
	b.rules = append(b.rules, bpfRouting{
		Type:     uint8(consts.RoutingType_DomainSet),
		Outbound: b.OutboundToId(outbound),
	})
}
// AddIp records an IP-set rule matching the prefixes in values and routed to
// outbound.
//
// The prefixes become a new entry of SimulatedLpmTries. The appended
// RoutingType_IpSet rule stores that entry's index in its Value field.
// The call does nothing if an error was already recorded.
func (b *RoutingMatcherBuilder) AddIp(values []netip.Prefix, outbound string) {
	if b.err == nil {
		b.SimulatedLpmTries = append(b.SimulatedLpmTries, values)
		trieIdx := uint32(len(b.SimulatedLpmTries) - 1)
		b.rules = append(b.rules, bpfRouting{
			Type:     uint8(consts.RoutingType_IpSet),
			Value:    trieIdx,
			Outbound: b.OutboundToId(outbound),
		})
	}
}
// AddFinal records outbound as the builder's Final outbound and appends a
// RoutingType_Final rule routed to it.
//
// Like AddDomain and AddIp, it does nothing once an error has been recorded.
// This keeps the first error intact for Build to report and avoids mutating
// builder state after a failure.
func (b *RoutingMatcherBuilder) AddFinal(outbound string) {
	if b.err != nil {
		return
	}
	b.Final = outbound
	b.rules = append(b.rules, bpfRouting{
		Type:     uint8(consts.RoutingType_Final),
		Outbound: b.OutboundToId(outbound),
	})
}
// Build flushes the accumulated routing state into the BPF maps.
//
// If an Add* call recorded an error, that error is returned and no map is
// touched. Otherwise Build does three things in order:
//  1. For every entry i of SimulatedLpmTries, it creates an LPM map holding
//     that entry's CIDRs (each with value 1) and stores it in LpmArrayMap at
//     key i.
//  2. It writes all rules into RoutingMap in one batch, using the keys from
//     common.ARangeU32(len(rules)) (presumably 0..len-1).
//  3. It stores the number of rules in ParamMap under consts.RoutingsLenKey.
func (b *RoutingMatcherBuilder) Build() error {
	if b.err != nil {
		return b.err
	}

	// Update lpm_array_map.
	for i, cidrs := range b.SimulatedLpmTries {
		var keys []bpfLpmKey
		var values []uint32
		for _, cidr := range cidrs {
			keys = append(keys, cidrToBpfLpmKey(cidr))
			values = append(values, 1)
		}
		m, err := b.bpf.NewLpmMap(keys, values)
		if err != nil {
			return fmt.Errorf("NewLpmMap: %w", err)
		}
		// Map values (inner maps) cannot be written with BatchUpdate, so insert
		// them one at a time.
		err = b.bpf.LpmArrayMap.Update(uint32(i), m, ebpf.UpdateAny)
		// Release our user-space handle on every path. After a successful
		// insert, the outer map keeps its own kernel reference to the inner map.
		// Close is best-effort here; its result is intentionally not checked.
		m.Close()
		if err != nil {
			return fmt.Errorf("Update: %w", err)
		}
	}

	// Update routings.
	routingsLen := uint32(len(b.rules))
	routingsKeys := common.ARangeU32(routingsLen)
	if _, err := b.bpf.RoutingMap.BatchUpdate(routingsKeys, b.rules, &ebpf.BatchOptions{
		ElemFlags: uint64(ebpf.UpdateAny),
	}); err != nil {
		return fmt.Errorf("BatchUpdate: %w", err)
	}
	// Publish the rule count only after every rule has been written.
	if err := b.bpf.ParamMap.Update(consts.RoutingsLenKey, routingsLen, ebpf.UpdateAny); err != nil {
		return fmt.Errorf("Update: %w", err)
	}
	return nil
}