From 9e54856003cb9c31ea4bc8aa1a01cf337f4f0cc5 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 7 Nov 2022 13:11:00 +0000 Subject: [PATCH] fix(trie): disallow empty byte slice node values --- internal/trie/node/copy.go | 3 +- internal/trie/node/decode_test.go | 14 ++++++- internal/trie/node/encode.go | 14 ++++--- internal/trie/node/encode_test.go | 45 +++++++++++++++++++++++ internal/trie/node/subvalue.go | 17 --------- internal/trie/node/subvalue_test.go | 57 ----------------------------- lib/trie/trie.go | 9 ++++- lib/trie/trie_endtoend_test.go | 6 ++- 8 files changed, 80 insertions(+), 85 deletions(-) delete mode 100644 internal/trie/node/subvalue.go delete mode 100644 internal/trie/node/subvalue_test.go diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index 6d7faa3feef..65b09d5b01c 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -87,7 +87,8 @@ func (n *Node) Copy(settings CopySettings) *Node { copy(cpy.Key, n.Key) } - // nil and []byte{} are encoded differently, watch out! + // nil and []byte{} values for branches result in a different node encoding, + // so we make sure a `nil` value stays `nil` in the copy.
if settings.CopyValue && n.SubValue != nil { cpy.SubValue = make([]byte, len(n.SubValue)) copy(cpy.SubValue, n.SubValue) diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 46c27012fb3..5c37e127bdf 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -327,7 +327,7 @@ func Test_decodeLeaf(t *testing.T) { errWrapped: ErrDecodeValue, errMessage: "cannot decode value: unknown prefix for compact uint: 255", }, - "zero value": { + "missing value data": { reader: bytes.NewBuffer([]byte{ 9, // key data // missing value data @@ -338,6 +338,18 @@ func Test_decodeLeaf(t *testing.T) { Key: []byte{9}, }, }, + "empty value data": { + reader: bytes.NewBuffer(concatByteSlices([][]byte{ + {9}, // key data + scaleEncodeByteSlice(t, nil), + })), + variant: leafVariant.bits, + partialKeyLength: 1, + leaf: &Node{ + Key: []byte{9}, + Dirty: true, + }, + }, "success": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ diff --git a/internal/trie/node/encode.go b/internal/trie/node/encode.go index 59ce5da172b..54283edbfea 100644 --- a/internal/trie/node/encode.go +++ b/internal/trie/node/encode.go @@ -35,7 +35,9 @@ func (n *Node) Encode(buffer Buffer) (err error) { return fmt.Errorf("cannot write LE key to buffer: %w", err) } - if n.Kind() == Branch { + kind := n.Kind() + nodeIsBranch := kind == Branch + if nodeIsBranch { childrenBitmap := common.Uint16ToBytes(n.ChildrenBitmap()) _, err = buffer.Write(childrenBitmap) if err != nil { @@ -43,9 +45,9 @@ func (n *Node) Encode(buffer Buffer) (err error) { } } - // check value is not nil for branch nodes, even though - // leaf nodes always have a non-nil value. - if n.SubValue != nil { + // Only encode node value if the node is a leaf or + // the node is a branch with a non-nil value.
+ if !nodeIsBranch || (nodeIsBranch && n.SubValue != nil) { encodedValue, err := scale.Marshal(n.SubValue) // TODO scale encoder to write to buffer if err != nil { return fmt.Errorf("cannot scale encode value: %w", err) @@ -57,14 +59,14 @@ func (n *Node) Encode(buffer Buffer) (err error) { } } - if n.Kind() == Branch { + if nodeIsBranch { err = encodeChildrenOpportunisticParallel(n.Children, buffer) if err != nil { return fmt.Errorf("cannot encode children of branch: %w", err) } } - if n.Kind() == Leaf { + if kind == Leaf { // TODO cache this for branches too and update test cases. // TODO remove this copying since it defeats the purpose of `buffer` // and the sync.Pool. diff --git a/internal/trie/node/encode_test.go b/internal/trie/node/encode_test.go index 1d0506da4ef..4cf70c5f3c0 100644 --- a/internal/trie/node/encode_test.go +++ b/internal/trie/node/encode_test.go @@ -127,6 +127,25 @@ func Test_Node_Encode(t *testing.T) { bufferBytesCall: true, expectedEncoding: []byte{1, 2, 3}, }, + "leaf with empty value success": { + node: &Node{ + Key: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{leafVariant.bits | 3}, // partial key length 3 + }, + { + written: []byte{0x01, 0x23}, + }, + { + written: []byte{0}, + }, + }, + bufferLenCall: true, + bufferBytesCall: true, + expectedEncoding: []byte{1, 2, 3}, + }, "clean branch with encoding": { node: &Node{ Children: make([]*Node, ChildrenCapacity), @@ -297,6 +316,32 @@ func Test_Node_Encode(t *testing.T) { }, }, }, + "branch without value and with children success": { + node: &Node{ + Key: []byte{1, 2, 3}, + Children: []*Node{ + nil, nil, nil, {Key: []byte{9}, SubValue: []byte{1}}, + nil, nil, nil, {Key: []byte{11}, SubValue: []byte{1}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{branchVariant.bits | 3}, // partial key length 3 + }, + { // key LE + written: []byte{0x01, 0x23}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // first children + written: []byte{16, 65, 
9, 4, 1}, + }, + { // second children + written: []byte{16, 65, 11, 4, 1}, + }, + }, + }, } for name, testCase := range testCases { diff --git a/internal/trie/node/subvalue.go b/internal/trie/node/subvalue.go deleted file mode 100644 index d00aab788f6..00000000000 --- a/internal/trie/node/subvalue.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2022 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package node - -import "bytes" - -// SubValueEqual returns true if the node subvalue is equal to the -// subvalue given as argument. In particular, it returns false -// if one subvalue is nil and the other subvalue is the empty slice. -func (n Node) SubValueEqual(subValue []byte) (equal bool) { - if len(subValue) == 0 && len(n.SubValue) == 0 { - return (subValue == nil && n.SubValue == nil) || - (subValue != nil && n.SubValue != nil) - } - return bytes.Equal(n.SubValue, subValue) -} diff --git a/internal/trie/node/subvalue_test.go b/internal/trie/node/subvalue_test.go deleted file mode 100644 index 3d190acc63c..00000000000 --- a/internal/trie/node/subvalue_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2022 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package node - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_Node_SubValueEqual(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - node Node - subValue []byte - equal bool - }{ - "nil node subvalue and nil subvalue": { - equal: true, - }, - "empty node subvalue and empty subvalue": { - node: Node{SubValue: []byte{}}, - subValue: []byte{}, - equal: true, - }, - "nil node subvalue and empty subvalue": { - subValue: []byte{}, - }, - "empty node subvalue and nil subvalue": { - node: Node{SubValue: []byte{}}, - }, - "equal non empty values": { - node: Node{SubValue: []byte{1, 2}}, - subValue: []byte{1, 2}, - equal: true, - }, - "not equal non empty values": { - node: Node{SubValue: []byte{1, 2}}, - subValue: []byte{1, 3}, - }, - 
} - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - node := testCase.node - - equal := node.SubValueEqual(testCase.subValue) - - assert.Equal(t, testCase.equal, equal) - }) - } -} diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 9aecb152615..c3fd7f320fc 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -342,6 +342,11 @@ func (t *Trie) Put(keyLE, value []byte) { func (t *Trie) insertKeyLE(keyLE, value []byte, deletedMerkleValues map[string]struct{}) { nibblesKey := codec.KeyLEToNibbles(keyLE) + if len(value) == 0 { + // Force the value being inserted to nil since we don't + // differentiate between nil and empty values. + value = nil + } t.root, _, _ = t.insert(t.root, nibblesKey, value, deletedMerkleValues) } @@ -374,7 +379,7 @@ func (t *Trie) insertInLeaf(parentLeaf *Node, key, value []byte, newParent *Node, mutated bool, nodesCreated uint32) { if bytes.Equal(parentLeaf.Key, key) { nodesCreated = 0 - if parentLeaf.SubValueEqual(value) { + if bytes.Equal(parentLeaf.SubValue, value) { mutated = false return parentLeaf, mutated, nodesCreated } @@ -455,7 +460,7 @@ func (t *Trie) insertInBranch(parentBranch *Node, key, value []byte, copySettings := node.DefaultCopySettings if bytes.Equal(key, parentBranch.Key) { - if parentBranch.SubValueEqual(value) { + if bytes.Equal(parentBranch.SubValue, value) { mutated = false return parentBranch, mutated, 0 } diff --git a/lib/trie/trie_endtoend_test.go b/lib/trie/trie_endtoend_test.go index 90261d665a3..5fb670e25f7 100644 --- a/lib/trie/trie_endtoend_test.go +++ b/lib/trie/trie_endtoend_test.go @@ -106,7 +106,11 @@ func Fuzz_Trie_PutAndGet_Single(f *testing.F) { trie := NewEmptyTrie() trie.Put(key, value) retrievedValue := trie.Get(key) - assert.Equal(t, retrievedValue, value) + if retrievedValue == nil { + assert.Empty(t, value) + } else { + assert.Equal(t, value, retrievedValue) + } }) }