optimize: optimize memory of domain matching trie

This commit is contained in:
mzz2017
2023-02-19 19:15:16 +08:00
parent a011c2a74c
commit c75db9397d
5 changed files with 347 additions and 20 deletions

143
common/bitlist/bitlist.go Normal file
View File

@ -0,0 +1,143 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package bitlist
import (
"fmt"
"github.com/mzz2017/softwind/common"
"github.com/mzz2017/softwind/pkg/zeroalloc/buffer"
"github.com/mzz2017/softwind/pool"
"math/bits"
)
// CompactBitList is a growable list of fixed-width units whose width may be
// any number of bits rather than a multiple of eight. Units are packed back
// to back in a byte buffer: unit n occupies the bits
// [n*unitBitSize, (n+1)*unitBitSize).
type CompactBitList struct {
	// b is the byte buffer that holds the packed units.
	b *buffer.Buffer

	// unitBitSize is the width of every unit, in bits.
	unitBitSize int
	// unitNum is one more than the highest unit index ever passed to Set;
	// Append writes at this index.
	unitNum int
	// size is initialised to zero by NewCompactBitList and is not read by
	// any of the methods of this type.
	size int
}
// NewCompactBitList returns a CompactBitList with no units written yet, in
// which every unit is unitBitSize bits wide. The backing buffer is obtained
// from buffer.NewBuffer(1).
func NewCompactBitList(unitBitSize int) *CompactBitList {
	list := new(CompactBitList) // size and unitNum start at their zero value
	list.unitBitSize = unitBitSize
	list.b = buffer.NewBuffer(1)
	return list
}
// Set stores v in the unit at index iUnit, growing the buffer first when the
// unit lies beyond its current length. It panics if v needs more than
// unitBitSize bits. The unit occupies the bits
// [iUnit*unitBitSize, (iUnit+1)*unitBitSize) of the buffer; bit k of v is
// written k positions after the unit's first bit.
//
// Set is not optimized yet: it copies the value one bit at a time.
func (m *CompactBitList) Set(iUnit int, v uint64) {
if bits.Len64(v) > m.unitBitSize {
panic(fmt.Sprintf("value %v exceeds unit bit size", v))
}
m.growByUnitIndex(iUnit)
b := m.b.Bytes()
// i is the byte that holds the unit's first bit; j is that bit's offset
// inside the byte.
i := iUnit * m.unitBitSize / 8
j := iUnit * m.unitBitSize % 8
// Each iteration writes the low (up to) 8 bits of v and then shifts them out.
for unitToTravel := m.unitBitSize; unitToTravel > 0; unitToTravel -= 8 {
k := 0
// Fill byte i from bit offset j upwards with bits 0.. of v.
for ; k < unitToTravel && j+k < 8; k++ {
b[i] &= ^(1 << (k + j)) // clear bit.
val := uint8((v & (1 << k)) << j)
b[i] |= val // set bit.
}
// Either every remaining bit of the unit has been written and we are done,
// or byte i is full and the rest of this 8-bit group spills into the next byte.
if k >= unitToTravel {
break
}
i++
// Remember the starting offset; j temporarily holds the number of bits
// already written so that bit k of v lands at bit k-j of the new byte.
bakJ := j
j = k
for ; k < unitToTravel && k < 8; k++ {
b[i] &= ^(1 << (k - j)) // clear bit.
val := uint8((v & (1 << k)) >> j)
b[i] |= val // set bit.
}
// Drop the 8 bits that were just written.
v >>= 8
// Restore the starting offset: the next group begins at the same bit
// offset, now in byte i.
j = (bakJ + 8) % 8
}
// Track one past the highest index written so Append knows where to write.
m.unitNum = common.Max(m.unitNum, iUnit+1)
}
// Get returns the value of the unit at index iUnit. It returns 0 when the
// buffer is too short to contain that unit in full.
func (m *CompactBitList) Get(iUnit int) (v uint64) {
// bitBoundary is the bit position just past the end of the requested unit.
bitBoundary := (iUnit + 1) * m.unitBitSize
if m.b.Len()*8 < bitBoundary {
return 0
}
b := m.b.Bytes()
// i is the byte that holds the unit's first bit; j is that bit's offset
// inside the byte.
i := iUnit * m.unitBitSize / 8
j := iUnit * m.unitBitSize % 8
var val uint8
// byteSpace is the number of bits from offset j to the end of byte i.
byteSpace := 8 - j
// Case 1: the unit starts and ends inside byte i.
// Example: j = 2 (byteSpace = 6) and unitBitSize = 2, so only bits 2..3
// belong to the unit. Shifting left by 4 drops the bits above the unit;
// shifting right by 4+2 drops the bits below it and moves it down to bit 0.
if byteSpace > m.unitBitSize {
toTrimLeft := byteSpace - m.unitBitSize
return uint64((b[i] << toTrimLeft) >> (toTrimLeft + j))
} else {
// Case 2: the unit uses every bit of byte i from offset j upwards,
// so only the low j bits have to be trimmed.
val = b[i] >> j
}
v |= uint64(val)
// offset is the number of bits collected so far, i.e. the position in v
// where the next bits belong.
offset := 8 - j
i++
// The rest of the unit starts on a byte boundary. Copy whole bytes while at
// least 8 bits remain.
unitToTravel := m.unitBitSize - offset
for ; unitToTravel >= 8; unitToTravel, i, offset = unitToTravel-8, i+1, offset+8 {
// All 8 bits of this byte belong to the unit.
v |= uint64(b[i]) << offset
}
if unitToTravel == 0 {
return v
}
// Fewer than 8 bits remain: they are the low unitToTravel bits of b[i].
// Example: unitToTravel = 3. Shifting the byte left by 5 discards its five
// high bits, which are not part of this unit; the result is then shifted so
// that the three remaining bits land at position offset in v.
toTrimLeft := 8 - unitToTravel
if offset > toTrimLeft {
v |= uint64(b[i]<<toTrimLeft) << (offset - toTrimLeft)
} else {
v |= uint64(b[i]<<toTrimLeft) >> (toTrimLeft - offset)
}
return v
}
// Append stores v in the unit at index unitNum, which Set keeps at one past
// the highest index written so far. Like Set, it panics if v needs more than
// unitBitSize bits.
func (m *CompactBitList) Append(v uint64) {
m.Set(m.unitNum, v)
}
// growByUnitIndex extends the buffer, when necessary, so that it is long
// enough to hold every bit of the unit at index i. It does nothing if the
// buffer already covers that unit.
func (m *CompactBitList) growByUnitIndex(i int) {
	endBit := (i + 1) * m.unitBitSize
	if endBit <= m.b.Len()*8 {
		return
	}
	// Round the required bit count up to whole bytes.
	wantBytes := (endBit + 7) / 8
	m.b.Extend(wantBytes - m.b.Len())
}
// Tighten copies the current contents into a newly made slice whose length
// equals the buffer's length, calls Put on the old buffer and carries on with
// a buffer built from the copy.
func (m *CompactBitList) Tighten() {
	old := m.b
	exact := pool.B(make([]byte, old.Len()))
	copy(exact, old.Bytes())
	// The data has been copied out, so the old buffer is no longer needed.
	old.Put()
	m.b = buffer.NewBufferFrom(exact)
}
// Put calls Put on the underlying buffer.
// NOTE(review): presumably this hands the buffer back to a pool, after which
// the list must not be used any more — confirm against buffer.Buffer.
func (m *CompactBitList) Put() {
m.b.Put()
}

View File

@ -0,0 +1,90 @@
/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
*/
package bitlist
import (
"fmt"
"testing"
)
// TestBitList6 writes a 6-bit pattern at several indices and checks that it
// reads back unchanged across Set, Append and Tighten, and that the buffer
// capacity after each step is the expected one.
func TestBitList6(t *testing.T) {
	const pattern = 0b110010
	list := NewCompactBitList(6)

	// mustHold fails the test unless the unit at idx reads back as pattern.
	mustHold := func(idx int) {
		t.Helper()
		if got := list.Get(idx); got != pattern {
			t.Fatalf("expect 0b%08b, got 0b%08b", pattern, got)
		}
	}

	list.Set(1, pattern)
	mustHold(1)
	list.Tighten()
	mustHold(1)

	list.Set(13, pattern)
	mustHold(13)
	list.Tighten()
	// 14 units * 6 bits = 84 bits, which need 11 bytes.
	if c := list.b.Cap(); c != 11 {
		t.Fatal("failed to tighten", c)
	}
	mustHold(13)

	list.Append(pattern)
	mustHold(14)
	if c := list.b.Cap(); c != 32 {
		t.Fatal("unexpected grow behavior", c)
	}
	list.Tighten()
	// 15 units * 6 bits = 90 bits, which need 12 bytes.
	if c := list.b.Cap(); c != 12 {
		t.Fatal("failed to tighten", c)
	}
}
// TestBitList19 repeats the round-trip checks with 19-bit units, which span
// several bytes, and additionally verifies that overwriting a unit with all
// zeros and its neighbour with all ones does not disturb the other.
func TestBitList19(t *testing.T) {
	const (
		pattern = 0b1110010110010110010
		zeros   = 0b0000000000000000000
		ones    = 0b1111111111111111111
	)
	list := NewCompactBitList(19)

	// mustHold fails the test unless the unit at idx reads back as want.
	mustHold := func(idx int, want uint64) {
		t.Helper()
		if got := list.Get(idx); got != want {
			t.Fatalf("expect 0b%019b, got 0b%019b", want, got)
		}
	}

	list.Set(1, pattern)
	mustHold(1, pattern)
	list.Tighten()
	mustHold(1, pattern)

	list.Set(13, pattern)
	mustHold(13, pattern)
	list.Tighten()
	// 14 units * 19 bits = 266 bits, which need 34 bytes.
	if c := list.b.Cap(); c != 34 {
		t.Fatal("failed to tighten", c)
	}
	mustHold(13, pattern)

	list.Append(pattern)
	mustHold(14, pattern)
	if c := list.b.Cap(); c != 128 {
		t.Fatal("unexpected grow behavior", c)
	}
	list.Tighten()
	// 15 units * 19 bits = 285 bits, which need 36 bytes.
	if c := list.b.Cap(); c != 36 {
		t.Fatal("failed to tighten", c)
	}

	// Neighbouring units must not bleed into each other.
	list.Set(1, zeros)
	mustHold(1, zeros)
	list.Set(2, ones)
	mustHold(2, ones)
	mustHold(1, zeros)
}

View File

@ -173,7 +173,7 @@ func (n *AhocorasickSlimtrie) Build() (err error) {
}
toBuild = ToSuffixTrieStrings(toBuild)
sort.Strings(toBuild)
n.trie[i] = trie.NewTrie(toBuild)
n.trie[i], err = trie.NewTrie(toBuild)
if err != nil {
return err
}

View File

@ -1,7 +1,62 @@
// Package succinct is modified from https://github.com/openacid/succinct/blob/loc100/sskv.go.
// Package trie is modified from https://github.com/openacid/succinct/blob/loc100/sskv.go.
// It is about 50% slower but uses less memory.
package trie
import "math/bits"
import (
"fmt"
"github.com/v2rayA/dae/common/bitlist"
"math/bits"
)
// table maps every byte that may appear in a key to a small code: the
// lowercase letters take 0-25, followed by '-', '.', '^' and '$' (26-29), the
// digits '1'-'9' (30-38) and finally '0' (39). Bytes that are not listed keep
// the zero value, which is also the code of 'a'.
var table = [256]byte{
	'a': 0,
	'b': 1,
	'c': 2,
	'd': 3,
	'e': 4,
	'f': 5,
	'g': 6,
	'h': 7,
	'i': 8,
	'j': 9,
	'k': 10,
	'l': 11,
	'm': 12,
	'n': 13,
	'o': 14,
	'p': 15,
	'q': 16,
	'r': 17,
	's': 18,
	't': 19,
	'u': 20,
	'v': 21,
	'w': 22,
	'x': 23,
	'y': 24,
	'z': 25,

	'-': 26,
	'.': 27,
	'^': 28,
	'$': 29,

	'1': 30,
	'2': 31,
	'3': 32,
	'4': 33,
	'5': 34,
	'6': 35,
	'7': 36,
	'8': 37,
	'9': 38,
	'0': 39,
}

// N is the number of distinct codes assigned in table (0 through 39).
const N = 40
// IsValidChar reports whether b has a code assigned in table.
func IsValidChar(b byte) bool {
	if b == 'a' {
		// 'a' is encoded as 0, which is indistinguishable from an
		// unassigned entry, so it has to be recognised explicitly.
		return true
	}
	return table[b] != 0
}
// Trie is a succinct, sorted and static string set impl with compacted trie as
// storage. The space cost is about half lower than the original data.
@ -45,14 +100,25 @@ import "math/bits"
// leaves: 0001001111
type Trie struct {
leaves, labelBitmap []uint64
labels []byte
ranks, selects []int32
labels *bitlist.CompactBitList
ranksBL, selectsBL *bitlist.CompactBitList
}
// NewTrie creates a new *Trie struct, from a slice of sorted strings.
func NewTrie(keys []string) *Trie {
func NewTrie(keys []string) (*Trie, error) {
// Check chars.
for _, key := range keys {
for _, c := range []byte(key) {
if !IsValidChar(c) {
return nil, fmt.Errorf("char out of range: %c", c)
}
}
}
ss := &Trie{}
ss.labels = bitlist.NewCompactBitList(bits.Len8(N))
lIdx := 0
type qElt struct{ s, e, col int }
@ -76,7 +142,7 @@ func NewTrie(keys []string) *Trie {
}
queue = append(queue, qElt{frm, j, elt.col + 1})
ss.labels = append(ss.labels, keys[frm][elt.col])
ss.labels.Append(uint64(table[keys[frm][elt.col]]))
setBit(&ss.labelBitmap, lIdx, 0)
lIdx++
}
@ -86,7 +152,32 @@ func NewTrie(keys []string) *Trie {
}
ss.init()
return ss
// Tighten.
ss.labels.Tighten()
leaves := make([]uint64, len(ss.leaves))
copy(leaves, ss.leaves)
ss.leaves = leaves
labelBitmap := make([]uint64, len(ss.labelBitmap))
copy(labelBitmap, ss.labelBitmap)
ss.labelBitmap = labelBitmap
ss.ranksBL = bitlist.NewCompactBitList(bits.Len64(uint64(ss.ranks[len(ss.ranks)-1])))
ss.selectsBL = bitlist.NewCompactBitList(bits.Len64(uint64(ss.selects[len(ss.selects)-1])))
for _, v := range ss.ranks {
ss.ranksBL.Append(uint64(v))
}
for _, v := range ss.selects {
ss.selectsBL.Append(uint64(v))
}
ss.ranksBL.Tighten()
ss.selectsBL.Tighten()
ss.ranks = nil
ss.selects = nil
return ss, nil
}
// HasPrefix query for a word and return whether a prefix of the word is in the Trie.
@ -95,6 +186,9 @@ func (ss *Trie) HasPrefix(word string) bool {
nodeId, bmIdx := 0, 0
for i := 0; i < len(word); i++ {
if getBit(ss.leaves, nodeId) != 0 {
return true
}
c := word[i]
for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
@ -102,21 +196,18 @@ func (ss *Trie) HasPrefix(word string) bool {
return false
}
if ss.labels[bmIdx-nodeId] == c {
if byte(ss.labels.Get(bmIdx-nodeId)) == table[c] {
break
}
}
// go to next level
nodeId = countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
if getBit(ss.leaves, nodeId) != 0 {
return true
}
bmIdx = selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nodeId-1) + 1
nodeId = countZeros(ss.labelBitmap, ss.ranksBL, bmIdx+1)
bmIdx = selectIthOne(ss.labelBitmap, ss.ranksBL, ss.selectsBL, nodeId-1) + 1
}
return false
return getBit(ss.leaves, nodeId) != 0
}
func setBit(bm *[]uint64, i int, v int) {
@ -155,8 +246,8 @@ func (ss *Trie) init() {
//
// countZeros("010010", 4) == 3
// // 012345
func countZeros(bm []uint64, ranks []int32, i int) int {
return i - int(ranks[i>>6]) - bits.OnesCount64(bm[i>>6]&(1<<uint(i&63)-1))
func countZeros(bm []uint64, ranks *bitlist.CompactBitList, i int) int {
	word := i >> 6
	// ranks.Get(word) is presumably the number of ones in all words before
	// this one — TODO confirm against the code that fills ranks.
	onesInEarlierWords := int(ranks.Get(word))
	// Count the ones of this word that sit below bit i.
	lowMask := uint64(1)<<uint(i&63) - 1
	onesInThisWord := bits.OnesCount64(bm[word] & lowMask)
	return i - onesInEarlierWords - onesInThisWord
}
// selectIthOne returns the index of the i-th "1" in a bitmap, on behalf of rank
@ -165,9 +256,9 @@ func countZeros(bm []uint64, ranks []int32, i int) int {
//
// selectIthOne("010010", 1) == 4
// // 012345
func selectIthOne(bm []uint64, ranks, selects []int32, i int) int {
base := int(selects[i>>6] & ^63)
findIthOne := i - int(ranks[base>>6])
func selectIthOne(bm []uint64, ranks, selects *bitlist.CompactBitList, i int) int {
base := int(selects.Get(i>>6)) & ^63
findIthOne := i - int(ranks.Get(base>>6))
for i := base >> 6; i < len(bm); i++ {
bitIdx := 0

View File

@ -8,7 +8,7 @@ package trie
import "testing"
func TestTrie(t *testing.T) {
trie := NewTrie([]string{
trie, err := NewTrie([]string{
"moc.cbatnetnoc.",
"moc.cbatnetnoc^",
"nc.",
@ -95,6 +95,9 @@ func TestTrie(t *testing.T) {
"zk.ytamlacbci.",
"zk.ytamlacbci^",
})
if err != nil {
t.Fatal(err)
}
if !(trie.HasPrefix("nc.tset^") == true) {
t.Fatal("^test.cn")
}