Skip to content

Commit

Permalink
trie: remove global pool for stacktrie nodes, in favour of local
Browse files Browse the repository at this point in the history
  • Loading branch information
holiman committed Oct 2, 2023
1 parent 171a932 commit 8418248
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
6 changes: 4 additions & 2 deletions core/types/hashing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,20 @@ func BenchmarkDeriveSha200(b *testing.B) {
var exp common.Hash
var got common.Hash
b.Run("std_trie", func(b *testing.B) {
hasher := trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase(), nil))
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
exp = types.DeriveSha(txs, trie.NewEmpty(trie.NewDatabase(rawdb.NewMemoryDatabase(), nil)))
exp = types.DeriveSha(txs, hasher)
}
})

b.Run("stack_trie", func(b *testing.B) {
hasher := trie.NewStackTrie(nil)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
got = types.DeriveSha(txs, trie.NewStackTrie(nil))
got = types.DeriveSha(txs, hasher)
}
})
if got != exp {
Expand Down
65 changes: 39 additions & 26 deletions trie/stacktrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package trie

import (
"errors"
"sync"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
Expand All @@ -27,8 +26,8 @@ import (

var (
ErrCommitDisabled = errors.New("no database for committing")
stPool = sync.Pool{New: func() any { return new(stNode) }}
_ = types.TrieHasher((*StackTrie)(nil))

_ = types.TrieHasher((*StackTrie)(nil)) // interface check
)

// NodeWriteFunc is used to provide all information of a dirty node for committing
Expand All @@ -41,17 +40,38 @@ type NodeWriteFunc = func(owner common.Hash, path []byte, hash common.Hash, blob
type StackTrie struct {
owner common.Hash // the owner of the trie
writeFn NodeWriteFunc // function for committing nodes, can be nil
root *stNode
h *hasher
root *stNode // the root node
h *hasher // hasher used for calculating hashes
pool []*stNode // local pool of nodes
}

const poolMax = 100

func (stack *StackTrie) getNode() *stNode {
if len(stack.pool) > 0 {
el := stack.pool[len(stack.pool)-1]
stack.pool[len(stack.pool)-1] = nil
stack.pool = stack.pool[:len(stack.pool)-1]
return el
}
return new(stNode)
}

func (stack *StackTrie) putNode(node *stNode) {
if len(stack.pool) < poolMax {
stack.pool = append(stack.pool, node.Reset())
}
}

// NewStackTrie allocates and initializes an empty trie.
func NewStackTrie(writeFn NodeWriteFunc) *StackTrie {
return &StackTrie{
st := &StackTrie{
writeFn: writeFn,
root: stPool.Get().(*stNode),
h: newHasher(false),
pool: make([]*stNode, 0, 100),
}
st.root = st.getNode()
return st
}

// NewStackTrieWithOwner allocates and initializes an empty trie, but with
Expand Down Expand Up @@ -83,7 +103,7 @@ func (st *StackTrie) MustUpdate(key, value []byte) {
func (stack *StackTrie) Reset() {
stack.owner = (common.Hash{})
stack.writeFn = nil
stack.root = stPool.Get().(*stNode)
stack.root = stack.getNode()
}

// stNode represents a node within a StackTrie
Expand All @@ -94,16 +114,14 @@ type stNode struct {
children [16]*stNode // list of children (for branch and exts)
}

func newLeaf(key, val []byte) *stNode {
st := stPool.Get().(*stNode)
func (st *stNode) toLeaf(key, val []byte) *stNode {
st.nodeType = leafNode
st.key = append(st.key, key...)
st.val = val
return st
}

func newExt(key []byte, child *stNode) *stNode {
st := stPool.Get().(*stNode)
func (st *stNode) toExt(key []byte, child *stNode) *stNode {
st.nodeType = extNode
st.key = append(st.key, key...)
st.children[0] = child
Expand Down Expand Up @@ -157,10 +175,9 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
break
}
}

// Add new child
if st.children[idx] == nil {
st.children[idx] = newLeaf(key[1:], value)
st.children[idx] = stack.getNode().toLeaf(key[1:], value)
} else {
stack.insert(st.children[idx], key[1:], value, append(prefix, key[0]))
}
Expand Down Expand Up @@ -189,7 +206,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// Break on the non-last byte, insert an intermediate
// extension. The path prefix of the newly-inserted
// extension should also contain the different byte.
n = newExt(st.key[diffidx+1:], st.children[0])
n = stack.getNode().toExt(st.key[diffidx+1:], st.children[0])
stack.hash(n, append(prefix, st.key[:diffidx+1]...))
} else {
// Break on the last byte, no need to insert
Expand All @@ -211,12 +228,12 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// the common prefix is at least one byte
// long, insert a new intermediate branch
// node.
st.children[0] = stPool.Get().(*stNode)
st.children[0] = stack.getNode()
st.children[0].nodeType = branchNode
p = st.children[0]
}
// Create a leaf for the inserted part
o := newLeaf(key[diffidx+1:], value)
o := stack.getNode().toLeaf(key[diffidx+1:], value)

// Insert both child leaves where they belong:
origIdx := st.key[diffidx]
Expand Down Expand Up @@ -252,7 +269,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// Convert current node into an ext,
// and insert a child branch node.
st.nodeType = extNode
st.children[0] = stPool.Get().(*stNode)
st.children[0] = stack.getNode()
st.children[0].nodeType = branchNode
p = st.children[0]
}
Expand All @@ -261,11 +278,11 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// value and another containing the new value. The child leaf
// is hashed directly in order to free up some memory.
origIdx := st.key[diffidx]
p.children[origIdx] = newLeaf(st.key[diffidx+1:], st.val)
p.children[origIdx] = stack.getNode().toLeaf(st.key[diffidx+1:], st.val)
stack.hash(p.children[origIdx], append(prefix, st.key[:diffidx+1]...))

newIdx := key[diffidx]
p.children[newIdx] = newLeaf(key[diffidx+1:], value)
p.children[newIdx] = stack.getNode().toLeaf(key[diffidx+1:], value)

// Finally, cut off the key part that has been passed
// over to the children.
Expand Down Expand Up @@ -323,10 +340,9 @@ func (stack *StackTrie) hash(st *stNode, path []byte) {
} else {
nodes.Children[i] = hashNode(child.val)
}

// Release child back to pool.
st.children[i] = nil
stPool.Put(child.Reset())
stack.putNode(child)
}

nodes.encode(stack.h.encbuf)
Expand All @@ -341,13 +357,11 @@ func (stack *StackTrie) hash(st *stNode, path []byte) {
} else {
n.Val = hashNode(st.children[0].val)
}

n.encode(stack.h.encbuf)
encodedNode = stack.h.encodedBytes()

// Release child back to pool.
stPool.Put(st.children[0].Reset())

stack.putNode(st.children[0])
st.children[0] = nil

case leafNode:
Expand All @@ -367,7 +381,6 @@ func (stack *StackTrie) hash(st *stNode, path []byte) {
st.val = common.CopyBytes(encodedNode)
return
}

// Write the hash to the 'val'. We allocate a new val here to not mutate
// input values
st.val = stack.h.hashData(encodedNode)
Expand Down

0 comments on commit 8418248

Please sign in to comment.