optimize(routing): fix slow domain++ ip routing (#133)

This commit is contained in:
mzz
2023-06-11 12:48:52 +08:00
committed by GitHub
parent e885e76adf
commit 40b553edc9
7 changed files with 82 additions and 70 deletions

View File

@ -17,13 +17,10 @@ import (
"github.com/daeuniverse/dae/config" "github.com/daeuniverse/dae/config"
"github.com/daeuniverse/dae/pkg/config_parser" "github.com/daeuniverse/dae/pkg/config_parser"
"github.com/daeuniverse/dae/pkg/trie" "github.com/daeuniverse/dae/pkg/trie"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
) )
var ValidCidrChars = trie.NewValidChars([]byte{'0', '1'})
type ResponseMatcherBuilder struct { type ResponseMatcherBuilder struct {
log *logrus.Logger log *logrus.Logger
upstreamName2Id map[string]uint8 upstreamName2Id map[string]uint8
@ -71,31 +68,6 @@ func (b *ResponseMatcherBuilder) upstreamToId(upstream string) (upstreamId const
return upstreamId, nil return upstreamId, nil
} }
func prefix2bin128(prefix netip.Prefix) (bin128 string) {
bits := prefix.Bits()
if prefix.Addr().Is4() {
bits += 96
}
ip := prefix.Addr().As16()
buf := buffer.NewBuffer(128)
defer buf.Put()
loop:
for i := 0; i < len(ip); i++ {
for j := 0; j < 8; j++ {
if (ip[i]>>j)&1 == 1 {
buf.WriteByte('1')
} else {
buf.WriteByte('0')
}
bits--
if bits == 0 {
break loop
}
}
}
return buf.String()
}
func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.Prefix, upstream *routing.Outbound) (err error) { func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.Prefix, upstream *routing.Outbound) (err error) {
upstreamId, err := b.upstreamToId(upstream.Name) upstreamId, err := b.upstreamToId(upstream.Name)
if err != nil { if err != nil {
@ -107,12 +79,7 @@ func (b *ResponseMatcherBuilder) addIp(f *config_parser.Function, cidrs []netip.
Not: f.Not, Not: f.Not,
Upstream: uint8(upstreamId), Upstream: uint8(upstreamId),
} }
var keys []string t, err := trie.NewTrieFromPrefixes(cidrs)
// Convert netip.Prefix -> '0' '1' string
for _, prefix := range cidrs {
keys = append(keys, prefix2bin128(prefix))
}
t, err := trie.NewTrie(keys, ValidCidrChars)
if err != nil { if err != nil {
return err return err
} }
@ -263,7 +230,7 @@ func (m *ResponseMatcher) Match(
domainMatchBitmap := m.domainMatcher.MatchDomainBitmap(qName) domainMatchBitmap := m.domainMatcher.MatchDomainBitmap(qName)
bin128 := make([]string, 0, len(ips)) bin128 := make([]string, 0, len(ips))
for _, ip := range ips { for _, ip := range ips {
bin128 = append(bin128, prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128))) bin128 = append(bin128, trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(ip.As16()), 128)))
} }
goodSubrule := false goodSubrule := false

View File

@ -6,10 +6,12 @@
package dialer package dialer
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common"
"io"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@ -580,6 +582,10 @@ func (d *Dialer) HttpCheck(ctx context.Context, u *netutils.URL, ip netip.Addr,
// Judge the status code. // Judge the status code.
if page := path.Base(req.URL.Path); strings.HasPrefix(page, "generate_") { if page := path.Base(req.URL.Path); strings.HasPrefix(page, "generate_") {
if strconv.Itoa(resp.StatusCode) != strings.TrimPrefix(page, "generate_") { if strconv.Itoa(resp.StatusCode) != strings.TrimPrefix(page, "generate_") {
b, _ := io.ReadAll(resp.Body)
buf := bytes.NewBuffer(nil)
_ = resp.Request.Write(buf)
d.Log.Debugln(buf.String(), "Resp: ", string(b))
return false, fmt.Errorf("unexpected status code: %v", resp.StatusCode) return false, fmt.Errorf("unexpected status code: %v", resp.StatusCode)
} }
return true, nil return true, nil

View File

@ -330,7 +330,7 @@ func NewControlPlane(
if err = builder.BuildKernspace(log); err != nil { if err = builder.BuildKernspace(log); err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err) return nil, fmt.Errorf("RoutingMatcherBuilder.BuildKernspace: %w", err)
} }
routingMatcher, err := builder.BuildUserspace(core.bpf.LpmArrayMap) routingMatcher, err := builder.BuildUserspace()
if err != nil { if err != nil {
return nil, fmt.Errorf("RoutingMatcherBuilder.BuildUserspace: %w", err) return nil, fmt.Errorf("RoutingMatcherBuilder.BuildUserspace: %w", err)
} }

View File

@ -1016,8 +1016,8 @@ route(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
key = match_set->type; key = match_set->type;
bpf_printk("key(match_set->type): %llu", key); bpf_printk("key(match_set->type): %llu", key);
bpf_printk("Skip to judge. bad_rule: %d, good_subrule: %d", bad_rule, bpf_printk("Skip to judge. bad_rule: %d, good_subrule: %d", isdns_must_goodsubrule_badrule&0b10,
good_subrule); isdns_must_goodsubrule_badrule&0b1);
#endif #endif
goto before_next_loop; goto before_next_loop;
} }
@ -1103,7 +1103,7 @@ route(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
before_next_loop: before_next_loop:
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
bpf_printk("good_subrule: %d, bad_rule: %d", good_subrule, bad_rule); bpf_printk("good_subrule: %d, bad_rule: %d", isdns_must_goodsubrule_badrule&0b10, isdns_must_goodsubrule_badrule&0b1);
#endif #endif
if (match_set->outbound != OUTBOUND_LOGICAL_OR) { if (match_set->outbound != OUTBOUND_LOGICAL_OR) {
// This match_set reaches the end of subrule. // This match_set reaches the end of subrule.
@ -1119,7 +1119,7 @@ route(const __u32 flag[6], const void *l4hdr, const __be32 saddr[4],
isdns_must_goodsubrule_badrule &= ~0b10; isdns_must_goodsubrule_badrule &= ~0b10;
} }
#ifdef __DEBUG_ROUTING #ifdef __DEBUG_ROUTING
bpf_printk("_bad_rule: %d", bad_rule); bpf_printk("_bad_rule: %d", isdns_must_goodsubrule_badrule&0b1);
#endif #endif
if ((match_set->outbound & OUTBOUND_LOGICAL_MASK) != if ((match_set->outbound & OUTBOUND_LOGICAL_MASK) !=
OUTBOUND_LOGICAL_MASK) { OUTBOUND_LOGICAL_MASK) {

View File

@ -8,6 +8,7 @@ package control
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/daeuniverse/dae/pkg/trie"
"net/netip" "net/netip"
"strconv" "strconv"
@ -328,12 +329,21 @@ func (b *RoutingMatcherBuilder) BuildKernspace(log *logrus.Logger) (err error) {
return nil return nil
} }
func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *RoutingMatcher, err error) { func (b *RoutingMatcherBuilder) BuildUserspace() (matcher *RoutingMatcher, err error) {
// Build domainMatcher // Build domainMatcher
domainMatcher := domain_matcher.NewAhocorasickSlimtrie(b.log, consts.MaxMatchSetLen) domainMatcher := domain_matcher.NewAhocorasickSlimtrie(b.log, consts.MaxMatchSetLen)
for _, domains := range b.simulatedDomainSet { for _, domains := range b.simulatedDomainSet {
domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key) domainMatcher.AddSet(domains.RuleIndex, domains.Domains, domains.Key)
} }
// Build Ip matcher.
var lpmMatcher []*trie.Trie
for _, prefixes := range b.simulatedLpmTries {
t, err := trie.NewTrieFromPrefixes(prefixes)
if err != nil {
return nil, err
}
lpmMatcher = append(lpmMatcher, t)
}
if err = domainMatcher.Build(); err != nil { if err = domainMatcher.Build(); err != nil {
return nil, err return nil, err
} }
@ -345,7 +355,7 @@ func (b *RoutingMatcherBuilder) BuildUserspace(lpmArrayMap *ebpf.Map) (matcher *
} }
return &RoutingMatcher{ return &RoutingMatcher{
lpmArrayMap: lpmArrayMap, lpmMatcher: lpmMatcher,
domainMatcher: domainMatcher, domainMatcher: domainMatcher,
matches: b.rules, matches: b.rules,
}, nil }, nil

View File

@ -8,16 +8,16 @@ package control
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/daeuniverse/dae/pkg/trie"
"net" "net"
"net/netip"
"github.com/cilium/ebpf"
"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts" "github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/routing" "github.com/daeuniverse/dae/component/routing"
) )
type RoutingMatcher struct { type RoutingMatcher struct {
lpmArrayMap *ebpf.Map lpmMatcher []*trie.Trie
domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher. domainMatcher routing.DomainMatcher // All domain matchSets use one DomainMatcher.
matches []bpfMatchSet matches []bpfMatchSet
@ -38,19 +38,12 @@ func (m *RoutingMatcher) Match(
if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len { if len(sourceAddr) != net.IPv6len || len(destAddr) != net.IPv6len || len(mac) != net.IPv6len {
return 0, 0, false, fmt.Errorf("bad address length") return 0, 0, false, fmt.Errorf("bad address length")
} }
lpmKeys := make([]*_bpfLpmKey, consts.MatchType_Mac+1)
lpmKeys[consts.MatchType_IpSet] = &_bpfLpmKey{ bin128s := make([]string, consts.MatchType_Mac+1)
PrefixLen: 128, bin128s[consts.MatchType_IpSet] = trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(destAddr)), 128))
Data: common.Ipv6ByteSliceToUint32Array(destAddr), bin128s[consts.MatchType_SourceIpSet] = trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(sourceAddr)), 128))
} bin128s[consts.MatchType_Mac] = trie.Prefix2bin128(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(mac)), 128))
lpmKeys[consts.MatchType_SourceIpSet] = &_bpfLpmKey{
PrefixLen: 128,
Data: common.Ipv6ByteSliceToUint32Array(sourceAddr),
}
lpmKeys[consts.MatchType_Mac] = &_bpfLpmKey{
PrefixLen: 128,
Data: common.Ipv6ByteSliceToUint32Array(mac),
}
var domainMatchBitmap []uint32 var domainMatchBitmap []uint32
if domain != "" { if domain != "" {
domainMatchBitmap = m.domainMatcher.MatchDomainBitmap(domain) domainMatchBitmap = m.domainMatcher.MatchDomainBitmap(domain)
@ -65,19 +58,10 @@ func (m *RoutingMatcher) Match(
switch consts.MatchType(match.Type) { switch consts.MatchType(match.Type) {
case consts.MatchType_IpSet, consts.MatchType_SourceIpSet, consts.MatchType_Mac: case consts.MatchType_IpSet, consts.MatchType_SourceIpSet, consts.MatchType_Mac:
lpmIndex := uint32(binary.LittleEndian.Uint16(match.Value[:])) lpmIndex := uint32(binary.LittleEndian.Uint16(match.Value[:]))
var lpm *ebpf.Map m := m.lpmMatcher[lpmIndex]
if err = m.lpmArrayMap.Lookup(lpmIndex, &lpm); err != nil { if m.HasPrefix(bin128s[match.Type]) {
//logrus.Debugln("m.lpmArrayMap.Lookup:", err) goodSubrule = true
break
} }
var v uint32
if err = lpm.Lookup(*lpmKeys[int(match.Type)], &v); err != nil {
_ = lpm.Close()
//logrus.Debugln("lpm.Lookup:", err, lpmKeys[int(match.Type)], match.Type, destAddr)
break
}
_ = lpm.Close()
goodSubrule = true
case consts.MatchType_DomainSet: case consts.MatchType_DomainSet:
if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 { if domainMatchBitmap != nil && (domainMatchBitmap[i/32]>>(i%32))&1 > 0 {
goodSubrule = true goodSubrule = true

View File

@ -5,13 +5,17 @@ package trie
import ( import (
"fmt" "fmt"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"math/bits" "math/bits"
"net/netip"
"sort" "sort"
"github.com/daeuniverse/dae/common" "github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/bitlist" "github.com/daeuniverse/dae/common/bitlist"
) )
var ValidCidrChars = NewValidChars([]byte{'0', '1'})
type ValidChars struct { type ValidChars struct {
table [256]byte table [256]byte
n uint16 n uint16
@ -87,6 +91,47 @@ type Trie struct {
chars *ValidChars chars *ValidChars
} }
func Prefix2bin128(prefix netip.Prefix) (bin128 string) {
n := prefix.Bits()
if n == -1 {
panic("! BadPrefix: " + prefix.String())
}
if prefix.Addr().Is4() {
n += 96
}
ip := prefix.Addr().As16()
buf := buffer.NewBuffer(128)
defer buf.Put()
loop:
for i := 0; i < len(ip); i++ {
for j := 7; j >= 0; j-- {
if (ip[i]>>j)&1 == 1 {
_ = buf.WriteByte('1')
} else {
_ = buf.WriteByte('0')
}
n--
if n == 0 {
break loop
}
}
}
return buf.String()
}
func NewTrieFromPrefixes(cidrs []netip.Prefix) (*Trie, error) {
var keys []string
// Convert netip.Prefix -> '0' '1' string
for _, prefix := range cidrs {
keys = append(keys, Prefix2bin128(prefix))
}
t, err := NewTrie(keys, ValidCidrChars)
if err != nil {
return nil, err
}
return t, nil
}
// NewTrie creates a new *Trie struct, from a slice of sorted strings. // NewTrie creates a new *Trie struct, from a slice of sorted strings.
func NewTrie(keys []string, chars *ValidChars) (*Trie, error) { func NewTrie(keys []string, chars *ValidChars) (*Trie, error) {
// Check chars. // Check chars.