From 71a8a4f3e0562253039002647043298400f24fa5 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 23 Nov 2021 13:05:49 +0100 Subject: [PATCH] fix(trie): memory leak fix in `lib/trie` (#2009) - Buffer and pool usage fixed and improved - Fix buffers not put back in pool - Write to buffer passed as arguments - Decouple pools for encoding, digests and hashers - Improve `sync.Pool` usage generally - Improve parallel encoding of branches and remove dependency on `x/sync` - Do not copy when not needed --- go.mod | 2 +- lib/trie/hash.go | 334 ++++++++++++++++++++++++++++-------------- lib/trie/hash_test.go | 79 ++++------ lib/trie/node.go | 105 ++++++++----- lib/trie/node_test.go | 96 +++++------- lib/trie/print.go | 34 +++-- lib/trie/trie.go | 16 +- 7 files changed, 396 insertions(+), 270 deletions(-) diff --git a/go.mod b/go.mod index 83a0567d4b..ee23087213 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,6 @@ require ( github.com/urfave/cli v1.22.5 github.com/wasmerio/go-ext-wasm v0.3.2-0.20200326095750-0a32be6068ec golang.org/x/crypto v0.0.0-20210813211128-0a44fdfbc16e - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 google.golang.org/protobuf v1.27.1 ) @@ -174,6 +173,7 @@ require ( go.uber.org/multierr v1.7.0 // indirect go.uber.org/zap v1.19.0 // indirect golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d // indirect + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912 // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect diff --git a/lib/trie/hash.go b/lib/trie/hash.go index 228fc7802a..887ed83731 100644 --- a/lib/trie/hash.go +++ b/lib/trie/hash.go @@ -5,99 +5,123 @@ package trie import ( "bytes" - "context" + "errors" + "fmt" "hash" + "io" "sync" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" "golang.org/x/crypto/blake2b" - 
"golang.org/x/sync/errgroup" ) -// Hasher is a wrapper around a hash function -type hasher struct { - hash hash.Hash - tmp bytes.Buffer - parallel bool // Whether to use parallel threads when hashing +var encodingBufferPool = &sync.Pool{ + New: func() interface{} { + const initialBufferCapacity = 1900000 // 1.9MB, from checking capacities at runtime + b := make([]byte, 0, initialBufferCapacity) + return bytes.NewBuffer(b) + }, } -// hasherPool creates a pool of Hasher. -var hasherPool = sync.Pool{ +var digestBufferPool = &sync.Pool{ New: func() interface{} { - h, _ := blake2b.New256(nil) - var buf bytes.Buffer - // This allocation will be helpful for encoding keys. This is the min buffer size. - buf.Grow(700) - - return &hasher{ - tmp: buf, - hash: h, - } + const bufferCapacity = 32 + b := make([]byte, 0, bufferCapacity) + return bytes.NewBuffer(b) }, } -// NewHasher create new Hasher instance -func newHasher(parallel bool) *hasher { - h := hasherPool.Get().(*hasher) - h.parallel = parallel - return h +var hasherPool = &sync.Pool{ + New: func() interface{} { + hasher, err := blake2b.New256(nil) + if err != nil { + panic("cannot create Blake2b-256 hasher: " + err.Error()) + } + return hasher + }, } -func (h *hasher) returnToPool() { - h.tmp.Reset() - h.hash.Reset() - hasherPool.Put(h) -} +func hashNode(n node, digestBuffer io.Writer) (err error) { + encodingBuffer := encodingBufferPool.Get().(*bytes.Buffer) + encodingBuffer.Reset() + defer encodingBufferPool.Put(encodingBuffer) + + const parallel = false -// Hash encodes the node and then hashes it if its encoded length is > 32 bytes -func (h *hasher) Hash(n node) (res []byte, err error) { - encNode, err := h.encode(n) + err = encodeNode(n, encodingBuffer, parallel) if err != nil { - return nil, err + return fmt.Errorf("cannot encode node: %w", err) } // if length of encoded leaf is less than 32 bytes, do not hash - if len(encNode) < 32 { - return encNode, nil + if encodingBuffer.Len() < 32 { + _, err = 
digestBuffer.Write(encodingBuffer.Bytes()) + return err } - h.hash.Reset() // otherwise, hash encoded node - _, err = h.hash.Write(encNode) - if err == nil { - res = h.hash.Sum(nil) + hasher := hasherPool.Get().(hash.Hash) + hasher.Reset() + defer hasherPool.Put(hasher) + + // Note: using the sync.Pool's buffer is useful here. + _, err = hasher.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot hash encoded node: %w", err) } - return res, err + _, err = digestBuffer.Write(hasher.Sum(nil)) + return err } -// encode is the high-level function wrapping the encoding for different node types -// encoding has the following format: +var ErrNodeTypeUnsupported = errors.New("node type is not supported") + +// encodeNode writes the encoding of the node to the buffer given. +// It is the high-level function wrapping the encoding for different +// node types. The encoding has the following format: // NodeHeader | Extra partial key length | Partial Key | Value -func (h *hasher) encode(n node) ([]byte, error) { +func encodeNode(n node, buffer *bytes.Buffer, parallel bool) (err error) { switch n := n.(type) { case *branch: - return h.encodeBranch(n) + err := encodeBranch(n, buffer, parallel) + if err != nil { + return fmt.Errorf("cannot encode branch: %w", err) + } + return nil case *leaf: - return h.encodeLeaf(n) + err := encodeLeaf(n, buffer) + if err != nil { + return fmt.Errorf("cannot encode leaf: %w", err) + } + + n.encodingMu.Lock() + defer n.encodingMu.Unlock() + + // TODO remove this copying since it defeats the purpose of `buffer` + // and the sync.Pool. 
+ n.encoding = make([]byte, buffer.Len()) + copy(n.encoding, buffer.Bytes()) + return nil case nil: - return []byte{0}, nil + buffer.Write([]byte{0}) + return nil + default: + return fmt.Errorf("%w: %T", ErrNodeTypeUnsupported, n) } - - return nil, nil } func encodeAndHash(n node) ([]byte, error) { - h := newHasher(false) - defer h.returnToPool() + buffer := digestBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer digestBufferPool.Put(buffer) - encChild, err := h.Hash(n) + err := hashNode(n, buffer) if err != nil { return nil, err } - scEncChild, err := scale.Marshal(encChild) + scEncChild, err := scale.Marshal(buffer.Bytes()) if err != nil { return nil, err } @@ -105,95 +129,187 @@ func encodeAndHash(n node) ([]byte, error) { } // encodeBranch encodes a branch with the encoding specified at the top of this package -func (h *hasher) encodeBranch(b *branch) ([]byte, error) { +// to the buffer given. +func encodeBranch(b *branch, buffer io.Writer, parallel bool) (err error) { if !b.dirty && b.encoding != nil { - return b.encoding, nil + _, err = buffer.Write(b.encoding) + if err != nil { + return fmt.Errorf("cannot write stored encoded branch to buffer: %w", err) + } + return nil } - h.tmp.Reset() encoding, err := b.header() - h.tmp.Write(encoding) if err != nil { - return nil, err + return fmt.Errorf("cannot encode branch header: %w", err) + } + + _, err = buffer.Write(encoding) + if err != nil { + return fmt.Errorf("cannot write encoded branch header to buffer: %w", err) } - h.tmp.Write(nibblesToKeyLE(b.key)) - h.tmp.Write(common.Uint16ToBytes(b.childrenBitmap())) + _, err = buffer.Write(nibblesToKeyLE(b.key)) + if err != nil { + return fmt.Errorf("cannot write encoded branch key to buffer: %w", err) + } + + _, err = buffer.Write(common.Uint16ToBytes(b.childrenBitmap())) + if err != nil { + return fmt.Errorf("cannot write branch children bitmap to buffer: %w", err) + } if b.value != nil { bytes, err := scale.Marshal(b.value) if err != nil { - return nil, err 
+ return fmt.Errorf("cannot scale encode branch value: %w", err) } - h.tmp.Write(bytes) - } - - if h.parallel { - wg, _ := errgroup.WithContext(context.Background()) - resBuff := make([][]byte, 16) - for i := 0; i < 16; i++ { - func(i int) { - wg.Go(func() error { - child := b.children[i] - if child == nil { - return nil - } - - var err error - resBuff[i], err = encodeAndHash(child) - if err != nil { - return err - } - return nil - }) - }(i) - } - if err := wg.Wait(); err != nil { - return nil, err + + _, err = buffer.Write(bytes) + if err != nil { + return fmt.Errorf("cannot write encoded branch value to buffer: %w", err) } + } + + if parallel { + err = encodeChildrenInParallel(b.children, buffer) + } else { + err = encodeChildrenSequentially(b.children, buffer) + } + if err != nil { + return fmt.Errorf("cannot encode children of branch: %w", err) + } - for _, v := range resBuff { - if v != nil { - h.tmp.Write(v) + return nil +} + +func encodeChildrenInParallel(children [16]node, buffer io.Writer) (err error) { + type result struct { + index int + buffer *bytes.Buffer + err error + } + + resultsCh := make(chan result) + + for i, child := range children { + go func(index int, child node) { + buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + // buffer is put back in the pool after processing its + // data in the select block below. 
+ + err := encodeChild(child, buffer) + + resultsCh <- result{ + index: index, + buffer: buffer, + err: err, } + }(i, child) + } + + currentIndex := 0 + resultBuffers := make([]*bytes.Buffer, len(children)) + for range children { + result := <-resultsCh + if result.err != nil && err == nil { // only set the first error we get + err = result.err } - } else { - for i := 0; i < 16; i++ { - if child := b.children[i]; child != nil { - scEncChild, err := encodeAndHash(child) - if err != nil { - return nil, err - } - h.tmp.Write(scEncChild) + + resultBuffers[result.index] = result.buffer + + // write as many completed buffers to the result buffer. + for currentIndex < len(children) && + resultBuffers[currentIndex] != nil { + // note buffer.Write copies the byte slice given as argument + _, writeErr := buffer.Write(resultBuffers[currentIndex].Bytes()) + if writeErr != nil && err == nil { + err = writeErr } + + encodingBufferPool.Put(resultBuffers[currentIndex]) + resultBuffers[currentIndex] = nil + + currentIndex++ } } - return h.tmp.Bytes(), nil + for _, buffer := range resultBuffers { + if buffer == nil { // already emptied and put back in pool + continue + } + encodingBufferPool.Put(buffer) + } + + return err } -// encodeLeaf encodes a leaf with the encoding specified at the top of this package -func (h *hasher) encodeLeaf(l *leaf) ([]byte, error) { - if !l.dirty && l.encoding != nil { - return l.encoding, nil +func encodeChildrenSequentially(children [16]node, buffer io.Writer) (err error) { + for _, child := range children { + err = encodeChild(child, buffer) + if err != nil { + return err + } + } + return nil +} + +func encodeChild(child node, buffer io.Writer) (err error) { + if child == nil { + return nil + } + + scaleEncodedChild, err := encodeAndHash(child) + if err != nil { + return fmt.Errorf("failed to hash and scale encode child: %w", err) + } + + _, err = buffer.Write(scaleEncodedChild) + if err != nil { + return fmt.Errorf("failed to write child to buffer: 
%w", err) } - h.tmp.Reset() + return nil +} + +// encodeLeaf encodes a leaf to the buffer given, with the encoding +// specified at the top of this package. +func encodeLeaf(l *leaf, buffer io.Writer) (err error) { + l.encodingMu.RLock() + defer l.encodingMu.RUnlock() + if !l.dirty && l.encoding != nil { + _, err = buffer.Write(l.encoding) + if err != nil { + return fmt.Errorf("cannot write stored encoding to buffer: %w", err) + } + return nil + } encoding, err := l.header() - h.tmp.Write(encoding) if err != nil { - return nil, err + return fmt.Errorf("cannot encode header: %w", err) } - h.tmp.Write(nibblesToKeyLE(l.key)) + _, err = buffer.Write(encoding) + if err != nil { + return fmt.Errorf("cannot write encoded header to buffer: %w", err) + } - bytes, err := scale.Marshal(l.value) + _, err = buffer.Write(nibblesToKeyLE(l.key)) if err != nil { - return nil, err + return fmt.Errorf("cannot write LE key to buffer: %w", err) + } + + bytes, err := scale.Marshal(l.value) // TODO scale encoder to write to buffer + if err != nil { + return err + } + + _, err = buffer.Write(bytes) + if err != nil { + return fmt.Errorf("cannot write scale encoded value to buffer: %w", err) } - h.tmp.Write(bytes) - l.encoding = h.tmp.Bytes() - return h.tmp.Bytes(), nil + return nil } diff --git a/lib/trie/hash_test.go b/lib/trie/hash_test.go index cccb9f5a84..ffe46403f9 100644 --- a/lib/trie/hash_test.go +++ b/lib/trie/hash_test.go @@ -7,86 +7,69 @@ import ( "bytes" "math/rand" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func generateRandBytes(size int) []byte { - r := *rand.New(rand.NewSource(rand.Int63())) - buf := make([]byte, r.Intn(size)+1) - r.Read(buf) + buf := make([]byte, rand.Intn(size)+1) + rand.Read(buf) return buf } func generateRand(size int) [][]byte { rt := make([][]byte, size) - r := *rand.New(rand.NewSource(rand.Int63())) for i := range rt { - buf := make([]byte, r.Intn(379)+1) - r.Read(buf) + buf := make([]byte, 
rand.Intn(379)+1) + rand.Read(buf) rt[i] = buf } return rt } -func TestNewHasher(t *testing.T) { - hasher := newHasher(false) - defer hasher.returnToPool() - - _, err := hasher.hash.Write([]byte("noot")) - if err != nil { - t.Error(err) - } - - sum := hasher.hash.Sum(nil) - if sum == nil { - t.Error("did not sum hash") - } - - hasher.hash.Reset() -} - func TestHashLeaf(t *testing.T) { - hasher := newHasher(false) - defer hasher.returnToPool() - n := &leaf{key: generateRandBytes(380), value: generateRandBytes(64)} - h, err := hasher.Hash(n) + + buffer := bytes.NewBuffer(nil) + const parallel = false + err := encodeNode(n, buffer, parallel) + if err != nil { t.Errorf("did not hash leaf node: %s", err) - } else if h == nil { + } else if buffer.Len() == 0 { t.Errorf("did not hash leaf node: nil") } } func TestHashBranch(t *testing.T) { - hasher := newHasher(false) - defer hasher.returnToPool() - n := &branch{key: generateRandBytes(380), value: generateRandBytes(380)} n.children[3] = &leaf{key: generateRandBytes(380), value: generateRandBytes(380)} - h, err := hasher.Hash(n) + + buffer := bytes.NewBuffer(nil) + const parallel = false + err := encodeNode(n, buffer, parallel) + if err != nil { t.Errorf("did not hash branch node: %s", err) - } else if h == nil { + } else if buffer.Len() == 0 { t.Errorf("did not hash branch node: nil") } } func TestHashShort(t *testing.T) { - hasher := newHasher(false) - defer hasher.returnToPool() - - n := &leaf{key: generateRandBytes(2), value: generateRandBytes(3)} - expected, err := hasher.encode(n) - if err != nil { - t.Fatal(err) + n := &leaf{ + key: generateRandBytes(2), + value: generateRandBytes(3), } - h, err := hasher.Hash(n) - if err != nil { - t.Errorf("did not hash leaf node: %s", err) - } else if h == nil { - t.Errorf("did not hash leaf node: nil") - } else if !bytes.Equal(h[:], expected) { - t.Errorf("did not return encoded node padded to 32 bytes: got %s", h) - } + encodingBuffer := bytes.NewBuffer(nil) + const parallel = 
false + err := encodeNode(n, encodingBuffer, parallel) + require.NoError(t, err) + + digestBuffer := bytes.NewBuffer(nil) + err = hashNode(n, digestBuffer) + require.NoError(t, err) + assert.Equal(t, encodingBuffer.Bytes(), digestBuffer.Bytes()) } diff --git a/lib/trie/node.go b/lib/trie/node.go index 080cc6063a..cc55da89c3 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -69,6 +69,7 @@ type ( dirty bool hash []byte encoding []byte + encodingMu sync.RWMutex generation uint64 sync.RWMutex } @@ -83,11 +84,12 @@ func (l *leaf) setGeneration(generation uint64) { } func (b *branch) copy() node { - b.Lock() - defer b.Unlock() + b.RLock() + defer b.RUnlock() + cpy := &branch{ key: make([]byte, len(b.key)), - children: [16]node{}, + children: b.children, // copy interface pointers value: nil, dirty: b.dirty, hash: make([]byte, len(b.hash)), @@ -95,7 +97,6 @@ func (b *branch) copy() node { generation: b.generation, } copy(cpy.key, b.key) - copy(cpy.children[:], b.children[:]) // nil and []byte{} are encoded differently, watch out! 
if b.value != nil { @@ -109,8 +110,12 @@ func (b *branch) copy() node { } func (l *leaf) copy() node { - l.Lock() - defer l.Unlock() + l.RLock() + defer l.RUnlock() + + l.encodingMu.RLock() + defer l.encodingMu.RUnlock() + cpy := &leaf{ key: make([]byte, len(l.key)), value: make([]byte, len(l.value)), @@ -132,7 +137,10 @@ func (b *branch) setEncodingAndHash(enc, hash []byte) { } func (l *leaf) setEncodingAndHash(enc, hash []byte) { + l.encodingMu.Lock() l.encoding = enc + l.encodingMu.Unlock() + l.hash = hash } @@ -211,71 +219,96 @@ func (b *branch) setKey(key []byte) { b.key = key } -func (b *branch) encodeAndHash() ([]byte, []byte, error) { +func (b *branch) encodeAndHash() (encoding, hash []byte, err error) { if !b.dirty && b.encoding != nil && b.hash != nil { return b.encoding, b.hash, nil } - hasher := newHasher(false) - enc, err := hasher.encodeBranch(b) + buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer encodingBufferPool.Put(buffer) + + err = encodeBranch(b, buffer, false) if err != nil { return nil, nil, err } - if len(enc) < 32 { - b.encoding = enc - b.hash = enc - return enc, enc, nil + bufferBytes := buffer.Bytes() + + b.encoding = make([]byte, len(bufferBytes)) + copy(b.encoding, bufferBytes) + encoding = b.encoding // no need to copy + + if buffer.Len() < 32 { + b.hash = make([]byte, len(bufferBytes)) + copy(b.hash, bufferBytes) + hash = b.hash // no need to copy + return encoding, hash, nil } - hash, err := common.Blake2bHash(enc) + // Note: using the sync.Pool's buffer is useful here. 
+ hashArray, err := common.Blake2bHash(buffer.Bytes()) if err != nil { return nil, nil, err } + b.hash = hashArray[:] + hash = b.hash // no need to copy - b.encoding = enc - b.hash = hash[:] - return enc, hash[:], nil + return encoding, hash, nil } -func (l *leaf) encodeAndHash() ([]byte, []byte, error) { +func (l *leaf) encodeAndHash() (encoding, hash []byte, err error) { + l.encodingMu.RLock() if !l.isDirty() && l.encoding != nil && l.hash != nil { + l.encodingMu.RUnlock() return l.encoding, l.hash, nil } - hasher := newHasher(false) - enc, err := hasher.encodeLeaf(l) + l.encodingMu.RUnlock() + + buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer encodingBufferPool.Put(buffer) + err = encodeLeaf(l, buffer) if err != nil { return nil, nil, err } - if len(enc) < 32 { - l.encoding = enc - l.hash = enc - return enc, enc, nil + bufferBytes := buffer.Bytes() + + l.encodingMu.Lock() + // TODO remove this copying since it defeats the purpose of `buffer` + // and the sync.Pool. + l.encoding = make([]byte, len(bufferBytes)) + copy(l.encoding, bufferBytes) + l.encodingMu.Unlock() + encoding = l.encoding // no need to copy + + if len(bufferBytes) < 32 { + l.hash = make([]byte, len(bufferBytes)) + copy(l.hash, bufferBytes) + hash = l.hash // no need to copy + return encoding, hash, nil } - hash, err := common.Blake2bHash(enc) + // Note: using the sync.Pool's buffer is useful here. 
+ hashArray, err := common.Blake2bHash(buffer.Bytes()) if err != nil { return nil, nil, err } - l.encoding = enc - l.hash = hash[:] - return enc, hash[:], nil + l.hash = hashArray[:] + hash = l.hash // no need to copy + + return encoding, hash, nil } func decodeBytes(in []byte) (node, error) { - r := &bytes.Buffer{} - _, err := r.Write(in) - if err != nil { - return nil, err - } - - return decode(r) + buffer := bytes.NewBuffer(in) + return decode(buffer) } -// Decode wraps the decoding of different node types back into a node +// decode wraps the decoding of different node types back into a node func decode(r io.Reader) (node, error) { header, err := readByte(r) if err != nil { diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index 754f662c5c..4bc452cfa3 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -11,6 +11,7 @@ import ( "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -136,44 +137,38 @@ func TestBranchEncode(t *testing.T) { for i, testKey := range randKeys { b := &branch{key: testKey, children: [16]node{}, value: randVals[i]} - expected := []byte{} + expected := bytes.NewBuffer(nil) header, err := b.header() if err != nil { t.Fatalf("Error when encoding header: %s", err) } - expected = append(expected, header...) - expected = append(expected, nibblesToKeyLE(b.key)...) - expected = append(expected, common.Uint16ToBytes(b.childrenBitmap())...) + expected.Write(header) + expected.Write(nibblesToKeyLE(b.key)) + expected.Write(common.Uint16ToBytes(b.childrenBitmap())) enc, err := scale.Marshal(b.value) if err != nil { t.Fatalf("Fail when encoding value with scale: %s", err) } - expected = append(expected, enc...) 
+ expected.Write(enc) for _, child := range b.children { - if child != nil { - hasher := newHasher(false) - defer hasher.returnToPool() - encChild, er := hasher.Hash(child) - if er != nil { - t.Errorf("Fail when encoding branch child: %s", er) - } - expected = append(expected, encChild[:]...) + if child == nil { + continue } - } - hasher := newHasher(false) - defer hasher.returnToPool() - res, err := hasher.encodeBranch(b) - if !bytes.Equal(res, expected) { - t.Errorf("Fail when encoding node: got %x expected %x", res, expected) - } else if err != nil { - t.Errorf("Fail when encoding node: %s", err) + err := hashNode(child, expected) + require.NoError(t, err) } + + buffer := bytes.NewBuffer(nil) + const parallel = false + err = encodeBranch(b, buffer, parallel) + require.NoError(t, err) + assert.Equal(t, expected.Bytes(), buffer.Bytes()) } } @@ -199,14 +194,10 @@ func TestLeafEncode(t *testing.T) { expected = append(expected, enc...) - hasher := newHasher(false) - defer hasher.returnToPool() - res, err := hasher.encodeLeaf(l) - if !bytes.Equal(res, expected) { - t.Errorf("Fail when encoding node: got %x expected %x", res, expected) - } else if err != nil { - t.Errorf("Fail when encoding node: %s", err) - } + buffer := bytes.NewBuffer(nil) + err = encodeLeaf(l, buffer) + require.NoError(t, err) + assert.Equal(t, expected, buffer.Bytes()) } } @@ -223,12 +214,10 @@ func TestEncodeRoot(t *testing.T) { t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, val) } - hasher := newHasher(false) - defer hasher.returnToPool() - _, err := hasher.encode(trie.root) - if err != nil { - t.Errorf("Fail to encode trie root: %s", err) - } + buffer := bytes.NewBuffer(nil) + const parallel = false + err := encodeNode(trie.root, buffer, parallel) + require.NoError(t, err) } } } @@ -250,18 +239,16 @@ func TestBranchDecode(t *testing.T) { {key: byteArray(573), children: [16]node{}, value: []byte{0x01}}, } - hasher := newHasher(false) - defer hasher.returnToPool() + 
buffer := bytes.NewBuffer(nil) + const parallel = false + for _, test := range tests { - enc, err := hasher.encodeBranch(test) + err := encodeBranch(test, buffer, parallel) require.NoError(t, err) res := new(branch) - r := &bytes.Buffer{} - _, err = r.Write(enc) - require.NoError(t, err) + err = res.decode(buffer, 0) - err = res.decode(r, 0) require.NoError(t, err) require.Equal(t, test.key, res.key) require.Equal(t, test.childrenBitmap(), res.childrenBitmap()) @@ -281,18 +268,14 @@ func TestLeafDecode(t *testing.T) { {key: byteArray(573), value: []byte{0x01}, dirty: true}, } - hasher := newHasher(false) - defer hasher.returnToPool() + buffer := bytes.NewBuffer(nil) + for _, test := range tests { - enc, err := hasher.encodeLeaf(test) + err := encodeLeaf(test, buffer) require.NoError(t, err) res := new(leaf) - r := &bytes.Buffer{} - _, err = r.Write(enc) - require.NoError(t, err) - - err = res.decode(r, 0) + err = res.decode(buffer, 0) require.NoError(t, err) res.hash = nil @@ -320,17 +303,14 @@ func TestDecode(t *testing.T) { &leaf{key: byteArray(573), value: []byte{0x01}}, } - hasher := newHasher(false) - defer hasher.returnToPool() - for _, test := range tests { - enc, err := hasher.encode(test) - require.NoError(t, err) + buffer := bytes.NewBuffer(nil) + const parallel = false - r := &bytes.Buffer{} - _, err = r.Write(enc) + for _, test := range tests { + err := encodeNode(test, buffer, parallel) require.NoError(t, err) - res, err := decode(r) + res, err := decode(buffer) require.NoError(t, err) switch n := test.(type) { diff --git a/lib/trie/print.go b/lib/trie/print.go index bba1f15123..ba72fde4a5 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -4,6 +4,7 @@ package trie import ( + "bytes" "fmt" "github.com/ChainSafe/gossamer/lib/common" @@ -25,15 +26,22 @@ func (t *Trie) String() string { func (t *Trie) string(tree gotree.Tree, curr node, idx int) { switch c := curr.(type) { case *branch: - hasher := newHasher(false) - defer hasher.returnToPool() - 
c.encoding, _ = hasher.encode(c) + buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + + const parallel = false + _ = encodeBranch(c, buffer, parallel) + c.encoding = buffer.Bytes() + var bstr string if len(c.encoding) > 1024 { bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) } else { bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) } + + encodingBufferPool.Put(buffer) + sub := tree.Add(bstr) for i, child := range c.children { if child != nil { @@ -41,22 +49,26 @@ func (t *Trie) string(tree gotree.Tree, curr node, idx int) { } } case *leaf: - hasher := newHasher(false) - defer hasher.returnToPool() - c.encoding, _ = hasher.encode(c) + buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + + _ = encodeLeaf(c, buffer) + + c.encodingMu.Lock() + defer c.encodingMu.Unlock() + c.encoding = buffer.Bytes() + var bstr string if len(c.encoding) > 1024 { bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) } else { bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) } + + encodingBufferPool.Put(buffer) + tree.Add(bstr) default: return } } - -// Print prints the trie through pre-order traversal -func (t *Trie) Print() { - fmt.Println(t.String()) -} diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 3a1afd7147..0705b9ba7f 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -106,11 +106,9 @@ func (t *Trie) RootNode() node { //nolint return t.root } -// EncodeRoot returns the encoded root of the trie -func (t *Trie) EncodeRoot() ([]byte, error) { - h := newHasher(t.parallel) - defer h.returnToPool() - return h.encode(t.RootNode()) +// encodeRoot returns the encoded root of the trie +func (t *Trie) encodeRoot(buffer *bytes.Buffer) (err error) { + return encodeNode(t.RootNode(), buffer, t.parallel) } // MustHash returns the hashed root 
of the trie. It panics if it fails to hash the root node. @@ -125,12 +123,16 @@ func (t *Trie) MustHash() common.Hash { // Hash returns the hashed root of the trie func (t *Trie) Hash() (common.Hash, error) { - encRoot, err := t.EncodeRoot() + buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer.Reset() + defer encodingBufferPool.Put(buffer) + + err := t.encodeRoot(buffer) if err != nil { return [32]byte{}, err } - return common.Blake2bHash(encRoot) + return common.Blake2bHash(buffer.Bytes()) } // Entries returns all the key-value pairs in the trie as a map of keys to values