2024-01-04 16:28:16 +07:00
|
|
|
/*
|
|
|
|
* SPDX-License-Identifier: AGPL-3.0-only
|
|
|
|
* Copyright (c) 2022-2024, daeuniverse Organization <dae@v2raya.org>
|
|
|
|
*/
|
|
|
|
|
2023-02-19 18:15:16 +07:00
|
|
|
// Package trie is modified from https://github.com/openacid/succinct/blob/loc100/sskv.go.
|
2023-02-25 01:38:21 +07:00
|
|
|
// It is roughly 30% slower than the original but saves more than 40% memory.
|
2023-02-19 18:15:16 +07:00
|
|
|
|
2023-02-19 13:08:13 +07:00
|
|
|
package trie
|
|
|
|
|
2023-02-19 18:15:16 +07:00
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"math/bits"
|
2023-06-11 11:48:52 +07:00
|
|
|
"net/netip"
|
2023-03-31 22:36:24 +07:00
|
|
|
"sort"
|
2023-04-23 12:27:29 +07:00
|
|
|
|
|
|
|
"github.com/daeuniverse/dae/common"
|
|
|
|
"github.com/daeuniverse/dae/common/bitlist"
|
2024-04-24 01:22:50 +07:00
|
|
|
"github.com/daeuniverse/outbound/pool"
|
2023-02-19 18:15:16 +07:00
|
|
|
)
|
|
|
|
|
2023-06-11 11:48:52 +07:00
|
|
|
// ValidCidrChars is the alphabet of the strings produced by Prefix2bin128:
// '0' is mapped to index 0 and '1' to index 1.
var ValidCidrChars = NewValidChars([]byte{'0', '1'})

// ValidChars is the alphabet a Trie accepts. Each valid byte is mapped to a
// compact index (its position in the slice given to NewValidChars), and the
// Trie stores that index instead of the raw byte.
type ValidChars struct {
	// table maps a byte to its index. Bytes outside the alphabet are left at 0,
	// so index 0 alone cannot tell "first char" from "invalid"; zeroChar
	// resolves that ambiguity.
	table [256]byte

	// n is the number of entries that were passed to NewValidChars.
	n uint16

	// zeroChar is the byte that owns index 0; meaningful only when n > 0.
	zeroChar byte
}

// NewValidChars builds a ValidChars in which validChars[i] gets index i.
//
// validChars should not contain duplicates: a repeated byte takes the index of
// its last occurrence, yet every occurrence still counts toward Size.
func NewValidChars(validChars []byte) (v *ValidChars) {
	v = new(ValidChars)
	for _, c := range validChars {
		if v.n == 0 {
			v.zeroChar = c
		}
		v.table[c] = byte(v.n)
		v.n++
	}
	return v
}

// Size returns the number of entries the alphabet was built from.
func (v *ValidChars) Size() int {
	return int(v.n)
}

// IsValidChar reports whether c belongs to the alphabet.
func (v *ValidChars) IsValidChar(c byte) bool {
	if v.table[c] > 0 {
		return true
	}
	// Index 0 is indistinguishable from "absent" in table, so the owner of
	// index 0 is recognised through zeroChar. Requiring n > 0 keeps an empty
	// alphabet (including the zero value) from accepting byte 0.
	return v.n > 0 && c == v.zeroChar
}
|
2023-02-19 13:08:13 +07:00
|
|
|
|
|
|
|
// Trie is a succinct, sorted and static string set impl with compacted trie as
|
|
|
|
// storage. The space cost is about half lower than the original data.
|
|
|
|
//
|
|
|
|
// # Implementation
|
|
|
|
//
|
|
|
|
// It stores sorted strings in a compacted trie(AKA prefix tree).
|
|
|
|
// A trie node has at most 256 outgoing labels.
|
|
|
|
// A label is just a single byte.
|
|
|
|
// E.g., [ab, abc, abcd, axy, buv] is represented with a trie like the following:
|
|
|
|
// (Numbers are node id)
|
|
|
|
//
|
|
|
|
// ^ -a-> 1 -b-> 3 $
|
|
|
|
// | | `c-> 6 $
|
|
|
|
// | | `d-> 9 $
|
|
|
|
// | `x-> 4 -y-> 7 $
|
|
|
|
// `b-> 2 -u-> 5 -v-> 8 $
|
|
|
|
//
|
|
|
|
// Internally it uses a packed []byte and a bitmap with `len([]byte)` bits to
|
|
|
|
// describe the outgoing labels of a node,:
|
|
|
|
//
|
|
|
|
// ^: ab 00
|
|
|
|
// 1: bx 00
|
|
|
|
// 2: u 0
|
|
|
|
// 3: c 0
|
|
|
|
// 4: y 0
|
|
|
|
// 5: v 0
|
|
|
|
// 6: d 0
|
|
|
|
// 7: ø
|
|
|
|
// 8: ø
|
|
|
|
// 9: ø
|
|
|
|
//
|
|
|
|
// In storage it packs labels together and bitmaps joined with separator `1`:
|
|
|
|
//
|
|
|
|
// labels(ignore space): "ab bx u c y v d"
|
|
|
|
// label bitmap: 0010010101010101111
|
|
|
|
//
|
|
|
|
// Finally leaf nodes are indicated by another bitmap `leaves`, in which a `1`
|
|
|
|
// at i-th bit indicates the i-th node is a leaf:
|
|
|
|
//
|
|
|
|
// leaves: 0001001111
|
|
|
|
type Trie struct {
|
|
|
|
leaves, labelBitmap []uint64
|
|
|
|
ranks, selects []int32
|
2023-02-19 18:15:16 +07:00
|
|
|
labels *bitlist.CompactBitList
|
|
|
|
ranksBL, selectsBL *bitlist.CompactBitList
|
2023-02-25 01:38:21 +07:00
|
|
|
|
|
|
|
chars *ValidChars
|
2023-02-19 13:08:13 +07:00
|
|
|
}
|
|
|
|
|
2023-06-11 11:48:52 +07:00
|
|
|
func Prefix2bin128(prefix netip.Prefix) (bin128 string) {
|
|
|
|
n := prefix.Bits()
|
|
|
|
if n == -1 {
|
|
|
|
panic("! BadPrefix: " + prefix.String())
|
|
|
|
}
|
|
|
|
if prefix.Addr().Is4() {
|
|
|
|
n += 96
|
|
|
|
}
|
|
|
|
ip := prefix.Addr().As16()
|
2023-11-15 13:32:57 +07:00
|
|
|
buf := pool.GetBuffer()
|
|
|
|
defer pool.PutBuffer(buf)
|
2023-06-11 11:48:52 +07:00
|
|
|
loop:
|
|
|
|
for i := 0; i < len(ip); i++ {
|
|
|
|
for j := 7; j >= 0; j-- {
|
|
|
|
if (ip[i]>>j)&1 == 1 {
|
|
|
|
_ = buf.WriteByte('1')
|
|
|
|
} else {
|
|
|
|
_ = buf.WriteByte('0')
|
|
|
|
}
|
|
|
|
n--
|
|
|
|
if n == 0 {
|
|
|
|
break loop
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return buf.String()
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewTrieFromPrefixes(cidrs []netip.Prefix) (*Trie, error) {
|
|
|
|
var keys []string
|
|
|
|
// Convert netip.Prefix -> '0' '1' string
|
|
|
|
for _, prefix := range cidrs {
|
|
|
|
keys = append(keys, Prefix2bin128(prefix))
|
|
|
|
}
|
|
|
|
t, err := NewTrie(keys, ValidCidrChars)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return t, nil
|
|
|
|
}
|
|
|
|
|
2023-02-19 13:08:13 +07:00
|
|
|
// NewTrie creates a new *Trie struct, from a slice of sorted strings.
|
2023-02-25 01:38:21 +07:00
|
|
|
func NewTrie(keys []string, chars *ValidChars) (*Trie, error) {
|
2023-02-19 18:15:16 +07:00
|
|
|
// Check chars.
|
2023-03-31 22:36:24 +07:00
|
|
|
keys = common.Deduplicate(keys)
|
|
|
|
sort.Strings(keys)
|
2023-02-19 18:15:16 +07:00
|
|
|
for _, key := range keys {
|
|
|
|
for _, c := range []byte(key) {
|
2023-02-25 01:38:21 +07:00
|
|
|
if !chars.IsValidChar(c) {
|
2023-02-19 18:15:16 +07:00
|
|
|
return nil, fmt.Errorf("char out of range: %c", c)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2023-02-19 13:08:13 +07:00
|
|
|
|
2023-02-25 01:38:21 +07:00
|
|
|
ss := &Trie{
|
|
|
|
chars: chars,
|
|
|
|
labels: bitlist.NewCompactBitList(bits.Len(uint(chars.Size()))),
|
|
|
|
}
|
2023-02-19 13:08:13 +07:00
|
|
|
lIdx := 0
|
|
|
|
|
|
|
|
type qElt struct{ s, e, col int }
|
|
|
|
|
|
|
|
queue := []qElt{{0, len(keys), 0}}
|
|
|
|
|
|
|
|
for i := 0; i < len(queue); i++ {
|
|
|
|
elt := queue[i]
|
|
|
|
|
|
|
|
if elt.col == len(keys[elt.s]) {
|
|
|
|
// a leaf node
|
|
|
|
elt.s++
|
|
|
|
setBit(&ss.leaves, i, 1)
|
|
|
|
}
|
|
|
|
|
|
|
|
for j := elt.s; j < elt.e; {
|
|
|
|
|
|
|
|
frm := j
|
|
|
|
|
|
|
|
for ; j < elt.e && keys[j][elt.col] == keys[frm][elt.col]; j++ {
|
|
|
|
}
|
|
|
|
|
|
|
|
queue = append(queue, qElt{frm, j, elt.col + 1})
|
2023-02-25 01:38:21 +07:00
|
|
|
ss.labels.Append(uint64(chars.table[keys[frm][elt.col]]))
|
2023-02-19 13:08:13 +07:00
|
|
|
setBit(&ss.labelBitmap, lIdx, 0)
|
|
|
|
lIdx++
|
|
|
|
}
|
|
|
|
|
|
|
|
setBit(&ss.labelBitmap, lIdx, 1)
|
|
|
|
lIdx++
|
|
|
|
}
|
|
|
|
|
|
|
|
ss.init()
|
2023-02-19 18:15:16 +07:00
|
|
|
|
|
|
|
// Tighten.
|
|
|
|
ss.labels.Tighten()
|
|
|
|
|
|
|
|
leaves := make([]uint64, len(ss.leaves))
|
|
|
|
copy(leaves, ss.leaves)
|
|
|
|
ss.leaves = leaves
|
|
|
|
|
|
|
|
labelBitmap := make([]uint64, len(ss.labelBitmap))
|
|
|
|
copy(labelBitmap, ss.labelBitmap)
|
|
|
|
ss.labelBitmap = labelBitmap
|
|
|
|
|
|
|
|
ss.ranksBL = bitlist.NewCompactBitList(bits.Len64(uint64(ss.ranks[len(ss.ranks)-1])))
|
|
|
|
ss.selectsBL = bitlist.NewCompactBitList(bits.Len64(uint64(ss.selects[len(ss.selects)-1])))
|
|
|
|
for _, v := range ss.ranks {
|
|
|
|
ss.ranksBL.Append(uint64(v))
|
|
|
|
}
|
|
|
|
for _, v := range ss.selects {
|
|
|
|
ss.selectsBL.Append(uint64(v))
|
|
|
|
}
|
|
|
|
ss.ranksBL.Tighten()
|
|
|
|
ss.selectsBL.Tighten()
|
|
|
|
ss.ranks = nil
|
|
|
|
ss.selects = nil
|
|
|
|
|
|
|
|
return ss, nil
|
2023-02-19 13:08:13 +07:00
|
|
|
}
|
|
|
|
|
|
|
|
// HasPrefix query for a word and return whether a prefix of the word is in the Trie.
|
|
|
|
func (ss *Trie) HasPrefix(word string) bool {
|
|
|
|
|
|
|
|
nodeId, bmIdx := 0, 0
|
|
|
|
|
|
|
|
for i := 0; i < len(word); i++ {
|
2023-02-19 18:15:16 +07:00
|
|
|
if getBit(ss.leaves, nodeId) != 0 {
|
|
|
|
return true
|
|
|
|
}
|
2023-02-19 13:08:13 +07:00
|
|
|
c := word[i]
|
2023-02-25 01:38:21 +07:00
|
|
|
if !ss.chars.IsValidChar(c) {
|
|
|
|
return false
|
|
|
|
}
|
2023-02-19 13:08:13 +07:00
|
|
|
for ; ; bmIdx++ {
|
|
|
|
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
|
|
|
// no more labels in this node
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
2023-02-25 01:38:21 +07:00
|
|
|
if byte(ss.labels.Get(bmIdx-nodeId)) == ss.chars.table[c] {
|
2023-02-19 13:08:13 +07:00
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// go to next level
|
|
|
|
|
2023-02-19 18:15:16 +07:00
|
|
|
nodeId = countZeros(ss.labelBitmap, ss.ranksBL, bmIdx+1)
|
|
|
|
bmIdx = selectIthOne(ss.labelBitmap, ss.ranksBL, ss.selectsBL, nodeId-1) + 1
|
2023-02-19 13:08:13 +07:00
|
|
|
}
|
|
|
|
|
2023-02-19 18:15:16 +07:00
|
|
|
return getBit(ss.leaves, nodeId) != 0
|
2023-02-19 13:08:13 +07:00
|
|
|
}
|
|
|
|
|
|
|
|
// setBit ORs v (expected to be 0 or 1) into bit i of *bm. It first appends zero
// words until bit i is addressable. Passing v == 0 therefore only grows the
// bitmap; it never clears a bit.
func setBit(bm *[]uint64, i int, v int) {
	word := i >> 6
	for len(*bm) <= word {
		*bm = append(*bm, 0)
	}
	(*bm)[word] |= uint64(v) << (uint(i) & 63)
}
|
|
|
|
|
|
|
|
// getBit returns bit i of bm still in place: the word with every other bit
// masked off. The result is therefore non-zero exactly when the bit is set;
// it is not normalised to 1.
func getBit(bm []uint64, i int) uint64 {
	mask := uint64(1) << (uint(i) & 63)
	return bm[i>>6] & mask
}
|
|
|
|
|
|
|
|
// init builds pre-calculated cache to speed up rank() and select()
|
|
|
|
func (ss *Trie) init() {
|
|
|
|
ss.ranks = []int32{0}
|
|
|
|
for i := 0; i < len(ss.labelBitmap); i++ {
|
|
|
|
n := bits.OnesCount64(ss.labelBitmap[i])
|
|
|
|
ss.ranks = append(ss.ranks, ss.ranks[len(ss.ranks)-1]+int32(n))
|
|
|
|
}
|
|
|
|
|
|
|
|
ss.selects = []int32{}
|
|
|
|
n := 0
|
|
|
|
for i := 0; i < len(ss.labelBitmap)<<6; i++ {
|
|
|
|
z := int(ss.labelBitmap[i>>6]>>uint(i&63)) & 1
|
|
|
|
if z == 1 && n&63 == 0 {
|
|
|
|
ss.selects = append(ss.selects, int32(i))
|
|
|
|
}
|
|
|
|
n += z
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// countZeros counts the number of "0" in a bitmap before the i-th bit(excluding
|
|
|
|
// the i-th bit) on behalf of rank index.
|
|
|
|
// E.g.:
|
|
|
|
//
|
|
|
|
// countZeros("010010", 4) == 3
|
|
|
|
// // 012345
|
2023-02-19 18:15:16 +07:00
|
|
|
func countZeros(bm []uint64, ranks *bitlist.CompactBitList, i int) int {
|
|
|
|
return i - int(ranks.Get(i>>6)) - bits.OnesCount64(bm[i>>6]&(1<<uint(i&63)-1))
|
2023-02-19 13:08:13 +07:00
|
|
|
}
|
|
|
|
|
|
|
|
// selectIthOne returns the bit index of the i-th "1" (0-based) in a bitmap,
// using the rank and select caches to skip ahead.
// E.g.:
//
//	selectIthOne("010010", 1) == 4
//	             // 012345
//
// ranks.Get(k) is expected to hold the number of 1 bits in words [0, k) of bm,
// and selects.Get(k) the bit index of the (64*k)-th 1 bit. These are the
// caches Trie.init builds and NewTrie repacks. It panics with "no more ones"
// if the scan runs off the end of bm.
func selectIthOne(bm []uint64, ranks, selects *bitlist.CompactBitList, i int) int {
	// The i-th 1 cannot precede the (64*(i>>6))-th 1, so start at the
	// 64-bit-aligned word that contains the latter.
	base := int(selects.Get(i>>6)) & ^63
	// Number of 1 bits still to skip once scanning starts at word base>>6.
	findIthOne := i - int(ranks.Get(base>>6))

	// Note: this loop's i shadows the parameter i and is a word index.
	for i := base >> 6; i < len(bm); i++ {
		bitIdx := 0
		for w := bm[i]; w > 0; {
			// After the first step, bit 0 of w is always a set bit.
			findIthOne -= int(w & 1)
			if findIthOne < 0 {
				return i<<6 + bitIdx
			}
			// Shift w so its next set bit lands on bit 0. If none remains,
			// TrailingZeros64(0) == 64 and the shift clears w, ending the loop.
			t0 := bits.TrailingZeros64(w &^ 1)
			w >>= uint(t0)
			bitIdx += t0
		}
	}

	panic("no more ones")
}
|