mirror of
https://github.com/daeuniverse/dae.git
synced 2025-07-04 23:40:30 +07:00
optimize: optimize memory of domain matching trie
This commit is contained in:
143
common/bitlist/bitlist.go
Normal file
143
common/bitlist/bitlist.go
Normal file
@ -0,0 +1,143 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package bitlist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/mzz2017/softwind/common"
|
||||
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
|
||||
"github.com/mzz2017/softwind/pool"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// CompactBitList allows your units to be of arbitrary bit size.
|
||||
type CompactBitList struct {
|
||||
unitBitSize int
|
||||
size int
|
||||
b *buffer.Buffer
|
||||
unitNum int
|
||||
}
|
||||
|
||||
func NewCompactBitList(unitBitSize int) *CompactBitList {
|
||||
return &CompactBitList{
|
||||
unitBitSize: unitBitSize,
|
||||
size: 0,
|
||||
b: buffer.NewBuffer(1),
|
||||
}
|
||||
}
|
||||
|
||||
// Set is not optimized yet.
|
||||
func (m *CompactBitList) Set(iUnit int, v uint64) {
|
||||
if bits.Len64(v) > m.unitBitSize {
|
||||
panic(fmt.Sprintf("value %v exceeds unit bit size", v))
|
||||
}
|
||||
m.growByUnitIndex(iUnit)
|
||||
b := m.b.Bytes()
|
||||
i := iUnit * m.unitBitSize / 8
|
||||
j := iUnit * m.unitBitSize % 8
|
||||
for unitToTravel := m.unitBitSize; unitToTravel > 0; unitToTravel -= 8 {
|
||||
k := 0
|
||||
for ; k < unitToTravel && j+k < 8; k++ {
|
||||
b[i] &= ^(1 << (k + j)) // clear bit.
|
||||
val := uint8((v & (1 << k)) << j)
|
||||
b[i] |= val // set bit.
|
||||
}
|
||||
// Now unitBitSize is traveled and we should break the loop,
|
||||
// OR we did not travel the byte and we need to travel the next byte.
|
||||
if k >= unitToTravel {
|
||||
break
|
||||
}
|
||||
i++
|
||||
bakJ := j
|
||||
j = k
|
||||
for ; k < unitToTravel && k < 8; k++ {
|
||||
b[i] &= ^(1 << (k - j)) // clear bit.
|
||||
val := uint8((v & (1 << k)) >> j)
|
||||
b[i] |= val // set bit.
|
||||
}
|
||||
v >>= 8
|
||||
j = (bakJ + 8) % 8
|
||||
}
|
||||
m.unitNum = common.Max(m.unitNum, iUnit+1)
|
||||
}
|
||||
|
||||
func (m *CompactBitList) Get(iUnit int) (v uint64) {
|
||||
bitBoundary := (iUnit + 1) * m.unitBitSize
|
||||
if m.b.Len()*8 < bitBoundary {
|
||||
return 0
|
||||
}
|
||||
|
||||
b := m.b.Bytes()
|
||||
i := iUnit * m.unitBitSize / 8
|
||||
j := iUnit * m.unitBitSize % 8
|
||||
|
||||
var val uint8
|
||||
byteSpace := 8 - j
|
||||
// 11111111
|
||||
// |
|
||||
// j byteSpace = 6, unitBitSize = 2
|
||||
// 11 We only copy those 2 bits, so we left shift 4 and right shift 4+2.
|
||||
if byteSpace > m.unitBitSize {
|
||||
toTrimLeft := byteSpace - m.unitBitSize
|
||||
return uint64((b[i] << toTrimLeft) >> (toTrimLeft + j))
|
||||
} else {
|
||||
// Trim right only.
|
||||
val = b[i] >> j
|
||||
}
|
||||
v |= uint64(val)
|
||||
|
||||
offset := 8 - j
|
||||
i++
|
||||
// Now we have multiple of 8 bits spaces to move.
|
||||
unitToTravel := m.unitBitSize - offset
|
||||
for ; unitToTravel >= 8; unitToTravel, i, offset = unitToTravel-8, i+1, offset+8 {
|
||||
// 11111111
|
||||
// |
|
||||
// p
|
||||
// 11111111 We copy whole 8 bits
|
||||
v |= uint64(b[i]) << offset
|
||||
}
|
||||
if unitToTravel == 0 {
|
||||
return v
|
||||
}
|
||||
|
||||
// 11111111
|
||||
// |
|
||||
// p unitToTravel = 3
|
||||
// 111 We only copy those 3 bits, so we left shift 5 and right shift 5.
|
||||
toTrimLeft := 8 - unitToTravel
|
||||
if offset > toTrimLeft {
|
||||
v |= uint64(b[i]<<toTrimLeft) << (offset - toTrimLeft)
|
||||
} else {
|
||||
v |= uint64(b[i]<<toTrimLeft) >> (toTrimLeft - offset)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (m *CompactBitList) Append(v uint64) {
|
||||
m.Set(m.unitNum, v)
|
||||
}
|
||||
|
||||
func (m *CompactBitList) growByUnitIndex(i int) {
|
||||
if bitBoundary := (i + 1) * m.unitBitSize; m.b.Len()*8 < bitBoundary {
|
||||
needBytes := bitBoundary / 8
|
||||
if bitBoundary%8 != 0 {
|
||||
needBytes++
|
||||
}
|
||||
m.b.Extend(needBytes - m.b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *CompactBitList) Tighten() {
|
||||
a := pool.B(make([]byte, m.b.Len()))
|
||||
copy(a, m.b.Bytes())
|
||||
m.b.Put()
|
||||
m.b = buffer.NewBufferFrom(a)
|
||||
}
|
||||
|
||||
func (m *CompactBitList) Put() {
|
||||
m.b.Put()
|
||||
}
|
90
common/bitlist/bitlist_test.go
Normal file
90
common/bitlist/bitlist_test.go
Normal file
@ -0,0 +1,90 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: AGPL-3.0-only
|
||||
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
|
||||
*/
|
||||
|
||||
package bitlist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBitList6(t *testing.T) {
|
||||
bm := NewCompactBitList(6)
|
||||
bm.Set(1, 0b110010)
|
||||
if v := bm.Get(1); v != 0b110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%08b, got 0b%08b", 0b110010, v))
|
||||
}
|
||||
bm.Tighten()
|
||||
if v := bm.Get(1); v != 0b110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%08b, got 0b%08b", 0b110010, v))
|
||||
}
|
||||
bm.Set(13, 0b110010)
|
||||
if v := bm.Get(13); v != 0b110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%08b, got 0b%08b", 0b110010, v))
|
||||
}
|
||||
bm.Tighten()
|
||||
if bm.b.Cap() != 11 {
|
||||
t.Fatal("failed to tighten", bm.b.Cap())
|
||||
}
|
||||
if v := bm.Get(13); v != 0b110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%08b, got 0b%08b", 0b110010, v))
|
||||
}
|
||||
bm.Append(0b110010)
|
||||
if v := bm.Get(14); v != 0b110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%08b, got 0b%08b", 0b110010, v))
|
||||
}
|
||||
if bm.b.Cap() != 32 {
|
||||
t.Fatal("unexpected grow behavior", bm.b.Cap())
|
||||
}
|
||||
bm.Tighten()
|
||||
if bm.b.Cap() != 12 {
|
||||
t.Fatal("failed to tighten", bm.b.Cap())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitList19(t *testing.T) {
|
||||
bm := NewCompactBitList(19)
|
||||
bm.Set(1, 0b1110010110010110010)
|
||||
if v := bm.Get(1); v != 0b1110010110010110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b1110010110010110010, v))
|
||||
}
|
||||
bm.Tighten()
|
||||
if v := bm.Get(1); v != 0b1110010110010110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b1110010110010110010, v))
|
||||
}
|
||||
bm.Set(13, 0b1110010110010110010)
|
||||
if v := bm.Get(13); v != 0b1110010110010110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b1110010110010110010, v))
|
||||
}
|
||||
bm.Tighten()
|
||||
if bm.b.Cap() != 34 {
|
||||
t.Fatal("failed to tighten", bm.b.Cap())
|
||||
}
|
||||
if v := bm.Get(13); v != 0b1110010110010110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b1110010110010110010, v))
|
||||
}
|
||||
bm.Append(0b1110010110010110010)
|
||||
if v := bm.Get(14); v != 0b1110010110010110010 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b1110010110010110010, v))
|
||||
}
|
||||
if bm.b.Cap() != 128 {
|
||||
t.Fatal("unexpected grow behavior", bm.b.Cap())
|
||||
}
|
||||
bm.Tighten()
|
||||
if bm.b.Cap() != 36 {
|
||||
t.Fatal("failed to tighten", bm.b.Cap())
|
||||
}
|
||||
bm.Set(1, 0b0000000000000000000)
|
||||
if v := bm.Get(1); v != 0b0000000000000000000 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b0000000000000000000, v))
|
||||
}
|
||||
bm.Set(2, 0b1111111111111111111)
|
||||
if v := bm.Get(2); v != 0b1111111111111111111 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b1111111111111111111, v))
|
||||
}
|
||||
if v := bm.Get(1); v != 0b0000000000000000000 {
|
||||
t.Fatal(fmt.Errorf("expect 0b%019b, got 0b%019b", 0b0000000000000000000, v))
|
||||
}
|
||||
}
|
@ -173,7 +173,7 @@ func (n *AhocorasickSlimtrie) Build() (err error) {
|
||||
}
|
||||
toBuild = ToSuffixTrieStrings(toBuild)
|
||||
sort.Strings(toBuild)
|
||||
n.trie[i] = trie.NewTrie(toBuild)
|
||||
n.trie[i], err = trie.NewTrie(toBuild)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
127
pkg/trie/trie.go
127
pkg/trie/trie.go
@ -1,7 +1,62 @@
|
||||
// Package succinct is modified from https://github.com/openacid/succinct/blob/loc100/sskv.go.
|
||||
// Package trie is modified from https://github.com/openacid/succinct/blob/loc100/sskv.go.
|
||||
// About 50% slower, but uses less memory.
|
||||
|
||||
package trie
|
||||
|
||||
import "math/bits"
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/v2rayA/dae/common/bitlist"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// table maps each byte of the supported alphabet to its compact code in
// [0, N). Bytes that are not listed keep the zero value, which is also the
// code of 'a'.
var table = [256]byte{
	// Lowercase letters: 0..25.
	'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6,
	'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13,
	'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20,
	'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25,

	// Punctuation and marker characters: 26..29.
	'-': 26, '.': 27, '^': 28, '$': 29,

	// Digits: 30..39, with '0' placed last.
	'1': 30, '2': 31, '3': 32, '4': 33, '5': 34,
	'6': 35, '7': 36, '8': 37, '9': 38, '0': 39,
}

// N is the number of distinct codes assigned in table.
const N = 40
|
||||
|
||||
func IsValidChar(b byte) bool {
|
||||
return table[b] > 0 || b == 'a'
|
||||
}
|
||||
|
||||
// Trie is a succinct, sorted and static string set impl with compacted trie as
|
||||
// storage. The space cost is about half lower than the original data.
|
||||
@ -45,14 +100,25 @@ import "math/bits"
|
||||
// leaves: 0001001111
|
||||
type Trie struct {
|
||||
leaves, labelBitmap []uint64
|
||||
labels []byte
|
||||
ranks, selects []int32
|
||||
labels *bitlist.CompactBitList
|
||||
ranksBL, selectsBL *bitlist.CompactBitList
|
||||
}
|
||||
|
||||
// NewTrie creates a new *Trie struct, from a slice of sorted strings.
|
||||
func NewTrie(keys []string) *Trie {
|
||||
func NewTrie(keys []string) (*Trie, error) {
|
||||
|
||||
// Check chars.
|
||||
for _, key := range keys {
|
||||
for _, c := range []byte(key) {
|
||||
if !IsValidChar(c) {
|
||||
return nil, fmt.Errorf("char out of range: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ss := &Trie{}
|
||||
ss.labels = bitlist.NewCompactBitList(bits.Len8(N))
|
||||
lIdx := 0
|
||||
|
||||
type qElt struct{ s, e, col int }
|
||||
@ -76,7 +142,7 @@ func NewTrie(keys []string) *Trie {
|
||||
}
|
||||
|
||||
queue = append(queue, qElt{frm, j, elt.col + 1})
|
||||
ss.labels = append(ss.labels, keys[frm][elt.col])
|
||||
ss.labels.Append(uint64(table[keys[frm][elt.col]]))
|
||||
setBit(&ss.labelBitmap, lIdx, 0)
|
||||
lIdx++
|
||||
}
|
||||
@ -86,7 +152,32 @@ func NewTrie(keys []string) *Trie {
|
||||
}
|
||||
|
||||
ss.init()
|
||||
return ss
|
||||
|
||||
// Tighten.
|
||||
ss.labels.Tighten()
|
||||
|
||||
leaves := make([]uint64, len(ss.leaves))
|
||||
copy(leaves, ss.leaves)
|
||||
ss.leaves = leaves
|
||||
|
||||
labelBitmap := make([]uint64, len(ss.labelBitmap))
|
||||
copy(labelBitmap, ss.labelBitmap)
|
||||
ss.labelBitmap = labelBitmap
|
||||
|
||||
ss.ranksBL = bitlist.NewCompactBitList(bits.Len64(uint64(ss.ranks[len(ss.ranks)-1])))
|
||||
ss.selectsBL = bitlist.NewCompactBitList(bits.Len64(uint64(ss.selects[len(ss.selects)-1])))
|
||||
for _, v := range ss.ranks {
|
||||
ss.ranksBL.Append(uint64(v))
|
||||
}
|
||||
for _, v := range ss.selects {
|
||||
ss.selectsBL.Append(uint64(v))
|
||||
}
|
||||
ss.ranksBL.Tighten()
|
||||
ss.selectsBL.Tighten()
|
||||
ss.ranks = nil
|
||||
ss.selects = nil
|
||||
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// HasPrefix query for a word and return whether a prefix of the word is in the Trie.
|
||||
@ -95,6 +186,9 @@ func (ss *Trie) HasPrefix(word string) bool {
|
||||
nodeId, bmIdx := 0, 0
|
||||
|
||||
for i := 0; i < len(word); i++ {
|
||||
if getBit(ss.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
c := word[i]
|
||||
for ; ; bmIdx++ {
|
||||
if getBit(ss.labelBitmap, bmIdx) != 0 {
|
||||
@ -102,21 +196,18 @@ func (ss *Trie) HasPrefix(word string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if ss.labels[bmIdx-nodeId] == c {
|
||||
if byte(ss.labels.Get(bmIdx-nodeId)) == table[c] {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// go to next level
|
||||
|
||||
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
|
||||
if getBit(ss.leaves, nodeId) != 0 {
|
||||
return true
|
||||
}
|
||||
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
|
||||
nodeId = countZeros(ss.labelBitmap, ss.ranksBL, bmIdx+1)
|
||||
bmIdx = selectIthOne(ss.labelBitmap, ss.ranksBL, ss.selectsBL, nodeId-1) + 1
|
||||
}
|
||||
|
||||
return false
|
||||
return getBit(ss.leaves, nodeId) != 0
|
||||
}
|
||||
|
||||
func setBit(bm *[]uint64, i int, v int) {
|
||||
@ -155,8 +246,8 @@ func (ss *Trie) init() {
|
||||
//
|
||||
// countZeros("010010", 4) == 3
|
||||
// // 012345
|
||||
func countZeros(bm []uint64, ranks []int32, i int) int {
|
||||
return i - int(ranks[i>>6]) - bits.OnesCount64(bm[i>>6]&(1<<uint(i&63)-1))
|
||||
func countZeros(bm []uint64, ranks *bitlist.CompactBitList, i int) int {
|
||||
return i - int(ranks.Get(i>>6)) - bits.OnesCount64(bm[i>>6]&(1<<uint(i&63)-1))
|
||||
}
|
||||
|
||||
// selectIthOne returns the index of the i-th "1" in a bitmap, on behalf of rank
|
||||
@ -165,9 +256,9 @@ func countZeros(bm []uint64, ranks []int32, i int) int {
|
||||
//
|
||||
// selectIthOne("010010", 1) == 4
|
||||
// // 012345
|
||||
func selectIthOne(bm []uint64, ranks, selects []int32, i int) int {
|
||||
base := int(selects[i>>6] & ^63)
|
||||
findIthOne := i - int(ranks[base>>6])
|
||||
func selectIthOne(bm []uint64, ranks, selects *bitlist.CompactBitList, i int) int {
|
||||
base := int(selects.Get(i>>6)) & ^63
|
||||
findIthOne := i - int(ranks.Get(base>>6))
|
||||
|
||||
for i := base >> 6; i < len(bm); i++ {
|
||||
bitIdx := 0
|
||||
|
@ -8,7 +8,7 @@ package trie
|
||||
import "testing"
|
||||
|
||||
func TestTrie(t *testing.T) {
|
||||
trie := NewTrie([]string{
|
||||
trie, err := NewTrie([]string{
|
||||
"moc.cbatnetnoc.",
|
||||
"moc.cbatnetnoc^",
|
||||
"nc.",
|
||||
@ -95,6 +95,9 @@ func TestTrie(t *testing.T) {
|
||||
"zk.ytamlacbci.",
|
||||
"zk.ytamlacbci^",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !(trie.HasPrefix("nc.tset^") == true) {
|
||||
t.Fatal("^test.cn")
|
||||
}
|
||||
|
Reference in New Issue
Block a user