Skip to content

Commit

Permalink
fix(trie): disallow empty byte slice node values
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Nov 9, 2022
1 parent 463a9b7 commit 9e54856
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 85 deletions.
3 changes: 2 additions & 1 deletion internal/trie/node/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func (n *Node) Copy(settings CopySettings) *Node {
copy(cpy.Key, n.Key)
}

// nil and []byte{} are encoded differently, watch out!
// nil and []byte{} values for branches result in a different node encoding,
// so we ensure to keep the `nil` value.
if settings.CopyValue && n.SubValue != nil {
cpy.SubValue = make([]byte, len(n.SubValue))
copy(cpy.SubValue, n.SubValue)
Expand Down
14 changes: 13 additions & 1 deletion internal/trie/node/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ func Test_decodeLeaf(t *testing.T) {
errWrapped: ErrDecodeValue,
errMessage: "cannot decode value: unknown prefix for compact uint: 255",
},
"zero value": {
"missing value data": {
reader: bytes.NewBuffer([]byte{
9, // key data
// missing value data
Expand All @@ -338,6 +338,18 @@ func Test_decodeLeaf(t *testing.T) {
Key: []byte{9},
},
},
"empty value data": {
reader: bytes.NewBuffer(concatByteSlices([][]byte{
{9}, // key data
scaleEncodeByteSlice(t, nil),
})),
variant: leafVariant.bits,
partialKeyLength: 1,
leaf: &Node{
Key: []byte{9},
Dirty: true,
},
},
"success": {
reader: bytes.NewBuffer(
concatByteSlices([][]byte{
Expand Down
14 changes: 8 additions & 6 deletions internal/trie/node/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,19 @@ func (n *Node) Encode(buffer Buffer) (err error) {
return fmt.Errorf("cannot write LE key to buffer: %w", err)
}

if n.Kind() == Branch {
kind := n.Kind()
nodeIsBranch := kind == Branch
if nodeIsBranch {
childrenBitmap := common.Uint16ToBytes(n.ChildrenBitmap())
_, err = buffer.Write(childrenBitmap)
if err != nil {
return fmt.Errorf("cannot write children bitmap to buffer: %w", err)
}
}

// check value is not nil for branch nodes, even though
// leaf nodes always have a non-nil value.
if n.SubValue != nil {
// Only encode node value if the node is a leaf or
// the node is a branch with a non empty value.
if !nodeIsBranch || (nodeIsBranch && n.SubValue != nil) {
encodedValue, err := scale.Marshal(n.SubValue) // TODO scale encoder to write to buffer
if err != nil {
return fmt.Errorf("cannot scale encode value: %w", err)
Expand All @@ -57,14 +59,14 @@ func (n *Node) Encode(buffer Buffer) (err error) {
}
}

if n.Kind() == Branch {
if nodeIsBranch {
err = encodeChildrenOpportunisticParallel(n.Children, buffer)
if err != nil {
return fmt.Errorf("cannot encode children of branch: %w", err)
}
}

if n.Kind() == Leaf {
if kind == Leaf {
// TODO cache this for branches too and update test cases.
// TODO remove this copying since it defeats the purpose of `buffer`
// and the sync.Pool.
Expand Down
45 changes: 45 additions & 0 deletions internal/trie/node/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,25 @@ func Test_Node_Encode(t *testing.T) {
bufferBytesCall: true,
expectedEncoding: []byte{1, 2, 3},
},
"leaf with empty value success": {
node: &Node{
Key: []byte{1, 2, 3},
},
writes: []writeCall{
{
written: []byte{leafVariant.bits | 3}, // partial key length 3
},
{
written: []byte{0x01, 0x23},
},
{
written: []byte{0},
},
},
bufferLenCall: true,
bufferBytesCall: true,
expectedEncoding: []byte{1, 2, 3},
},
"clean branch with encoding": {
node: &Node{
Children: make([]*Node, ChildrenCapacity),
Expand Down Expand Up @@ -297,6 +316,32 @@ func Test_Node_Encode(t *testing.T) {
},
},
},
"branch without value and with children success": {
node: &Node{
Key: []byte{1, 2, 3},
Children: []*Node{
nil, nil, nil, {Key: []byte{9}, SubValue: []byte{1}},
nil, nil, nil, {Key: []byte{11}, SubValue: []byte{1}},
},
},
writes: []writeCall{
{ // header
written: []byte{branchVariant.bits | 3}, // partial key length 3
},
{ // key LE
written: []byte{0x01, 0x23},
},
{ // children bitmap
written: []byte{136, 0},
},
{ // first children
written: []byte{16, 65, 9, 4, 1},
},
{ // second children
written: []byte{16, 65, 11, 4, 1},
},
},
},
}

for name, testCase := range testCases {
Expand Down
17 changes: 0 additions & 17 deletions internal/trie/node/subvalue.go

This file was deleted.

57 changes: 0 additions & 57 deletions internal/trie/node/subvalue_test.go

This file was deleted.

9 changes: 7 additions & 2 deletions lib/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ func (t *Trie) Put(keyLE, value []byte) {

func (t *Trie) insertKeyLE(keyLE, value []byte, deletedMerkleValues map[string]struct{}) {
nibblesKey := codec.KeyLEToNibbles(keyLE)
if len(value) == 0 {
// Force value to be inserted to nil since we don't
// differentiate between nil and empty values.
value = nil
}
t.root, _, _ = t.insert(t.root, nibblesKey, value, deletedMerkleValues)
}

Expand Down Expand Up @@ -374,7 +379,7 @@ func (t *Trie) insertInLeaf(parentLeaf *Node, key, value []byte,
newParent *Node, mutated bool, nodesCreated uint32) {
if bytes.Equal(parentLeaf.Key, key) {
nodesCreated = 0
if parentLeaf.SubValueEqual(value) {
if bytes.Equal(parentLeaf.SubValue, value) {
mutated = false
return parentLeaf, mutated, nodesCreated
}
Expand Down Expand Up @@ -455,7 +460,7 @@ func (t *Trie) insertInBranch(parentBranch *Node, key, value []byte,
copySettings := node.DefaultCopySettings

if bytes.Equal(key, parentBranch.Key) {
if parentBranch.SubValueEqual(value) {
if bytes.Equal(parentBranch.SubValue, value) {
mutated = false
return parentBranch, mutated, 0
}
Expand Down
6 changes: 5 additions & 1 deletion lib/trie/trie_endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ func Fuzz_Trie_PutAndGet_Single(f *testing.F) {
trie := NewEmptyTrie()
trie.Put(key, value)
retrievedValue := trie.Get(key)
assert.Equal(t, retrievedValue, value)
if retrievedValue == nil {
assert.Empty(t, value)
} else {
assert.Equal(t, value, retrievedValue)
}
})
}

Expand Down

0 comments on commit 9e54856

Please sign in to comment.