Skip to content

Commit

Permalink
membuffer: compare the keys of ART by chunk (#1482)
Browse files Browse the repository at this point in the history
ref pingcap/tidb#55287

Signed-off-by: you06 <[email protected]>
  • Loading branch information
you06 authored Nov 7, 2024
1 parent 0232600 commit c154447
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 15 deletions.
8 changes: 4 additions & 4 deletions internal/unionstore/art/art_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ func (it *baseIter) next() artNode {
if idx >= 0 && idx < int(n4.nodeNum) {
it.idxes[depth] = idx
child = &n4.children[idx]
} else if idx == int(n4.nodeNum) {
// idx == n4.nodeNum means this node is drain, break to pop stack.
} else if idx >= int(n4.nodeNum) {
// idx >= n4.nodeNum means this node is drain, break to pop stack.
break
} else {
panicForInvalidIndex(idx)
Expand All @@ -380,8 +380,8 @@ func (it *baseIter) next() artNode {
if idx >= 0 && idx < int(n16.nodeNum) {
it.idxes[depth] = idx
child = &n16.children[idx]
} else if idx == int(n16.nodeNum) {
// idx == n16.nodeNum means this node is drain, break to pop stack.
} else if idx >= int(n16.nodeNum) {
// idx >= n16.nodeNum means this node is drain, break to pop stack.
break
} else {
panicForInvalidIndex(idx)
Expand Down
56 changes: 45 additions & 11 deletions internal/unionstore/art/art_node.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"math"
"math/bits"
"runtime"
"sort"
"testing"
"unsafe"
Expand Down Expand Up @@ -310,14 +311,7 @@ func (n *nodeBase) setPrefix(key artKey, prefixLen uint32) {
// Node if the nodeBase.prefixLen > maxPrefixLen and the returned mismatch index equals to maxPrefixLen,
// key[maxPrefixLen:] will not be checked by this function.
func (n *nodeBase) match(key artKey, depth uint32) uint32 /* mismatch index */ {
idx := uint32(0)
limit := min(min(n.prefixLen, maxPrefixLen), uint32(len(key))-depth)
for ; idx < limit; idx++ {
if n.prefix[idx] != key[idx+depth] {
return idx
}
}
return idx
return longestCommonPrefix(key[depth:], n.prefix[:min(n.prefixLen, maxPrefixLen)], 0)
}

// matchDeep returns the mismatch index of the key and the node's prefix.
Expand Down Expand Up @@ -352,13 +346,53 @@ func (an *artNode) asNode256(a *artAllocator) *node256 {
return a.getNode256(an.addr)
}

// for amd64 and arm64 architectures, we use the chunk comparison to speed up finding the longest common prefix.
const enableChunkComparison = runtime.GOARCH == "amd64" || runtime.GOARCH == "arm64"

// longestCommonPrefix returns the length of the longest common prefix of two keys.
// the LCP is calculated from the given depth, you need to guarantee l1Key[:depth] equals to l2Key[:depth] before calling this function.
func longestCommonPrefix(l1Key, l2Key artKey, depth uint32) uint32 {
if enableChunkComparison {
return longestCommonPrefixByChunk(l1Key, l2Key, depth)
}
// For other architectures, we use the byte-by-byte comparison.
idx, limit := depth, uint32(min(len(l1Key), len(l2Key)))
for ; idx < limit; idx++ {
if l1Key[idx] != l2Key[idx] {
break
}
}
return idx - depth
}

// longestCommonPrefixByChunk compares two keys by 8 bytes at a time, which is significantly faster when the keys are long.
// Note this function only support architecture which is under little-endian and can read memory across unaligned address.
func longestCommonPrefixByChunk(l1Key, l2Key artKey, depth uint32) uint32 {
idx, limit := depth, uint32(min(len(l1Key), len(l2Key)))
// TODO: possible optimization
// Compare the key by loop can be very slow if the final LCP is large.
// Maybe optimize it by comparing the key in chunks if the limit exceeds certain threshold.

if idx == limit {
return 0
}

p1 := unsafe.Pointer(&l1Key[depth])
p2 := unsafe.Pointer(&l2Key[depth])

// Compare 8 bytes at a time
remaining := limit - depth
for remaining >= 8 {
if *(*uint64)(p1) != *(*uint64)(p2) {
// Find first different byte using trailing zeros
xor := *(*uint64)(p1) ^ *(*uint64)(p2)
return limit - remaining + uint32(bits.TrailingZeros64(xor)>>3) - depth
}

p1 = unsafe.Add(p1, 8)
p2 = unsafe.Add(p2, 8)
remaining -= 8
}

// Compare rest bytes
idx = limit - remaining
for ; idx < limit; idx++ {
if l1Key[idx] != l2Key[idx] {
break
Expand Down
11 changes: 11 additions & 0 deletions internal/unionstore/art/art_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,14 @@ func TestMinimumNode(t *testing.T) {
check(typeNode48)
check(typeNode256)
}

func TestKey2Chunk(t *testing.T) {
key := artKey([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})

for i := 0; i < len(key); i++ {
diffKey := make(artKey, len(key))
copy(diffKey, key)
diffKey[i] = 255
require.Equal(t, uint32(i), longestCommonPrefix(key, diffKey, 0))
}
}
21 changes: 21 additions & 0 deletions internal/unionstore/memdb_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"context"
"encoding/binary"
"math/rand"
"slices"
"testing"
)

Expand Down Expand Up @@ -250,3 +251,23 @@ func BenchmarkMemBufferCache(b *testing.B) {
b.Run("RBT", func(b *testing.B) { fn(b, newRbtDBWithContext()) })
b.Run("ART", func(b *testing.B) { fn(b, newArtDBWithContext()) })
}

func BenchmarkMemBufferSetGetLongKey(b *testing.B) {
fn := func(b *testing.B, buffer MemBuffer) {
keys := make([][]byte, b.N)
for i := 0; i < b.N; i++ {
keys[i] = make([]byte, 1024)
binary.BigEndian.PutUint64(keys[i], uint64(i))
slices.Reverse(keys[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
buffer.Set(keys[i], keys[i])
}
for i := 0; i < b.N; i++ {
buffer.Get(context.Background(), keys[i])
}
}
b.Run("RBT", func(b *testing.B) { fn(b, newRbtDBWithContext()) })
b.Run("ART", func(b *testing.B) { fn(b, newArtDBWithContext()) })
}

0 comments on commit c154447

Please sign in to comment.