From ccf0218e4317bd3d39ba0b26a0488e9e6c9dfd46 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 23 Nov 2021 13:38:34 +0000 Subject: [PATCH 01/50] export all trie node methods --- lib/trie/database.go | 36 +++++++++++------------ lib/trie/hash.go | 2 ++ lib/trie/lookup.go | 2 +- lib/trie/node.go | 66 +++++++++++++++++++++--------------------- lib/trie/node_test.go | 4 +-- lib/trie/print.go | 2 +- lib/trie/proof_test.go | 2 +- lib/trie/trie.go | 52 ++++++++++++++++----------------- lib/trie/trie_test.go | 2 +- 9 files changed, 85 insertions(+), 83 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 528664d078..2c63d4b4bc 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -36,7 +36,7 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { return nil } - enc, hash, err := curr.encodeAndHash() + enc, hash, err := curr.EncodeAndHash() if err != nil { return err } @@ -59,8 +59,8 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { } } - if curr.isDirty() { - curr.setDirty(false) + if curr.IsDirty() { + curr.SetDirty(false) } return nil @@ -82,10 +82,10 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return err } - decNode.setDirty(false) - decNode.setEncodingAndHash(rawNode, nil) + decNode.SetDirty(false) + decNode.SetEncodingAndHash(rawNode, nil) - _, computedRoot, err := decNode.encodeAndHash() + _, computedRoot, err := decNode.EncodeAndHash() if err != nil { return err } @@ -114,7 +114,7 @@ func (t *Trie) loadProof(proof map[string]node, curr node) { continue } - proofNode, ok := proof[common.BytesToHex(child.getHash())] + proofNode, ok := proof[common.BytesToHex(child.GetHash())] if !ok { continue } @@ -142,8 +142,8 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return err } - t.root.setDirty(false) - t.root.setEncodingAndHash(enc, root[:]) + t.root.SetDirty(false) + t.root.SetEncodingAndHash(enc, root[:]) return t.load(db, t.root) } @@ -155,7 +155,7 @@ func (t *Trie) load(db chaindb.Database, curr node) error { continue } - hash := child.getHash() + hash := child.GetHash() enc, err := db.Get(hash) if err != nil { return fmt.Errorf("failed to find node key=%x index=%d: %w", child.(*leaf).hash, i, err) @@ -166,8 +166,8 @@ func (t *Trie) load(db chaindb.Database, curr node) error { return err } - child.setDirty(false) - child.setEncodingAndHash(enc, hash) + child.SetDirty(false) + child.SetEncodingAndHash(enc, hash) c.children[i] = child err = t.load(db, child) @@ -188,7 +188,7 @@ func (t *Trie) GetNodeHashes(curr node, keys map[common.Hash]struct{}) error { continue } - hash := child.getHash() + hash := child.GetHash() keys[common.BytesToHash(hash)] = struct{}{} err := t.GetNodeHashes(child, keys) @@ -309,11 +309,11 @@ func (t *Trie) WriteDirty(db chaindb.Database) error { } func (t *Trie) writeDirty(db chaindb.Batch, curr node) error { - if curr == nil || !curr.isDirty() { + if curr == nil || !curr.IsDirty() { return nil } - enc, hash, err := curr.encodeAndHash() + enc, hash, err := curr.EncodeAndHash() if err != nil { return err } @@ -346,7 +346,7 @@ func (t *Trie) writeDirty(db chaindb.Batch, curr node) error { } } - curr.setDirty(false) + curr.SetDirty(false) return nil } @@ -358,11 +358,11 @@ func (t *Trie) GetInsertedNodeHashes() ([]common.Hash, error) { func (t *Trie) getInsertedNodeHashes(curr node) ([]common.Hash, error) { var nodeHashes []common.Hash - if curr == nil || !curr.isDirty() { + if curr == nil || !curr.IsDirty() { return nil, nil } - enc, hash, err := curr.encodeAndHash() + 
enc, hash, err := curr.EncodeAndHash() if err != nil { return nil, err } diff --git a/lib/trie/hash.go b/lib/trie/hash.go index ecb674ea82..87a18d439e 100644 --- a/lib/trie/hash.go +++ b/lib/trie/hash.go @@ -36,6 +36,8 @@ var hasherPool = &sync.Pool{ New: func() interface{} { hasher, err := blake2b.New256(nil) if err != nil { + // Conversation on why we panic here: + // https://github.com/ChainSafe/gossamer/pull/2009#discussion_r753430764 panic("cannot create Blake2b-256 hasher: " + err.Error()) } return hasher diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index c15501a2d3..948f270780 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -13,7 +13,7 @@ func findAndRecord(t *Trie, key []byte, recorder *recorder) error { } func find(parent node, key []byte, recorder *recorder) error { - enc, hash, err := parent.encodeAndHash() + enc, hash, err := parent.EncodeAndHash() if err != nil { return err } diff --git a/lib/trie/node.go b/lib/trie/node.go index 1def45a4c5..381460fa7d 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -40,17 +40,17 @@ import ( // node is the interface for trie methods type node interface { - encodeAndHash() ([]byte, []byte, error) - decode(r io.Reader, h byte) error - isDirty() bool - setDirty(dirty bool) - setKey(key []byte) + EncodeAndHash() ([]byte, []byte, error) + Decode(r io.Reader, h byte) error + IsDirty() bool + SetDirty(dirty bool) + SetKey(key []byte) String() string - setEncodingAndHash([]byte, []byte) - getHash() []byte - getGeneration() uint64 - setGeneration(uint64) - copy() node + SetEncodingAndHash([]byte, []byte) + GetHash() []byte + GetGeneration() uint64 + SetGeneration(uint64) + Copy() node } type ( @@ -76,15 +76,15 @@ type ( } ) -func (b *branch) setGeneration(generation uint64) { +func (b *branch) SetGeneration(generation uint64) { b.generation = generation } -func (l *leaf) setGeneration(generation uint64) { +func (l *leaf) SetGeneration(generation uint64) { l.generation = generation } -func (b *branch) copy() node { +func (b *branch) Copy() node { b.RLock() defer b.RUnlock() @@ -110,7 +110,7 @@ func (b *branch) copy() node { return cpy } -func (l *leaf) copy() node { +func (l *leaf) Copy() node { l.RLock() defer l.RUnlock() @@ -132,12 +132,12 @@ func (l *leaf) copy() node { return cpy } -func (b *branch) setEncodingAndHash(enc, hash []byte) { +func (b *branch) SetEncodingAndHash(enc, hash []byte) { b.encoding = enc b.hash = hash } -func (l *leaf) setEncodingAndHash(enc, hash []byte) { +func (l *leaf) SetEncodingAndHash(enc, hash []byte) { l.encodingMu.Lock() l.encoding = enc l.encodingMu.Unlock() @@ -145,19 +145,19 @@ func (l *leaf) setEncodingAndHash(enc, hash []byte) { l.hash = hash } -func (b *branch) getHash() []byte { +func (b *branch) GetHash() []byte { return b.hash } -func (b *branch) getGeneration() uint64 { +func (b *branch) GetGeneration() uint64 { return b.generation } -func (l *leaf) getGeneration() uint64 { +func (l *leaf) GetGeneration() uint64 { return l.generation } -func (l *leaf) getHash() []byte { +func (l *leaf) GetHash() []byte { return l.hash } @@ -198,31 +198,31 @@ func (b *branch) numChildren() int { return count } -func (l *leaf) isDirty() bool { +func (l *leaf) IsDirty() bool { return l.dirty } -func (b *branch) isDirty() bool { +func (b *branch) IsDirty() bool { return b.dirty } -func (l *leaf) setDirty(dirty bool) { +func (l *leaf) SetDirty(dirty bool) { l.dirty = dirty } -func (b *branch) setDirty(dirty bool) { +func (b *branch) SetDirty(dirty bool) { b.dirty = dirty } -func (l *leaf) setKey(key 
[]byte) { +func (l *leaf) SetKey(key []byte) { l.key = key } -func (b *branch) setKey(key []byte) { +func (b *branch) SetKey(key []byte) { b.key = key } -func (b *branch) encodeAndHash() (encoding, hash []byte, err error) { +func (b *branch) EncodeAndHash() (encoding, hash []byte, err error) { if !b.dirty && b.encoding != nil && b.hash != nil { return b.encoding, b.hash, nil } @@ -260,9 +260,9 @@ func (b *branch) encodeAndHash() (encoding, hash []byte, err error) { return encoding, hash, nil } -func (l *leaf) encodeAndHash() (encoding, hash []byte, err error) { +func (l *leaf) EncodeAndHash() (encoding, hash []byte, err error) { l.encodingMu.RLock() - if !l.isDirty() && l.encoding != nil && l.hash != nil { + if !l.IsDirty() && l.encoding != nil && l.hash != nil { l.encodingMu.RUnlock() return l.encoding, l.hash, nil } @@ -321,11 +321,11 @@ func decode(r io.Reader) (node, error) { nodeType := header >> 6 if nodeType == 1 { l := new(leaf) - err := l.decode(r, header) + err := l.Decode(r, header) return l, err } else if nodeType == 2 || nodeType == 3 { b := new(branch) - err := b.decode(r, header) + err := b.Decode(r, header) return b, err } @@ -335,7 +335,7 @@ func decode(r io.Reader) (node, error) { // Decode decodes a byte array with the encoding specified at the top of this package into a branch node // Note that since the encoded branch stores the hash of the children nodes, we aren't able to reconstruct the child // nodes from the encoding. This function instead stubs where the children are known to be with an empty leaf. -func (b *branch) decode(r io.Reader, header byte) (err error) { +func (b *branch) Decode(r io.Reader, header byte) (err error) { if header == 0 { header, err = readByte(r) if err != nil { @@ -392,7 +392,7 @@ func (b *branch) decode(r io.Reader, header byte) (err error) { } // Decode decodes a byte array with the encoding specified at the top of this package into a leaf node -func (l *leaf) decode(r io.Reader, header byte) (err error) { +func (l *leaf) Decode(r io.Reader, header byte) (err error) { if header == 0 { header, err = readByte(r) if err != nil { diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index 04e69a796d..f91a0cb2e4 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -266,7 +266,7 @@ func TestBranchDecode(t *testing.T) { require.NoError(t, err) res := new(branch) - err = res.decode(buffer, 0) + err = res.Decode(buffer, 0) require.NoError(t, err) require.Equal(t, test.key, res.key) @@ -294,7 +294,7 @@ func TestLeafDecode(t *testing.T) { require.NoError(t, err) res := new(leaf) - err = res.decode(buffer, 0) + err = res.Decode(buffer, 0) require.NoError(t, err) res.hash = nil diff --git a/lib/trie/print.go b/lib/trie/print.go index ba72fde4a5..0604de165b 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -18,7 +18,7 @@ func (t *Trie) String() string { return "empty" } - tree := gotree.New(fmt.Sprintf("Trie root=0x%x", t.root.getHash())) + tree := gotree.New(fmt.Sprintf("Trie root=0x%x", t.root.GetHash())) t.string(tree, t.root, 0) return fmt.Sprintf("\n%s", tree.Print()) } diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 2e472e2fdc..7c190d1c3c 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -69,7 +69,7 @@ func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][ err = trie.Store(memdb) require.NoError(t, err) - root := trie.root.getHash() + root := trie.root.GetHash() proof, err := GenerateProof(root, keys, memdb) require.NoError(t, err) diff --git 
a/lib/trie/trie.go b/lib/trie/trie.go index fe422a2e74..578a11e18a 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -69,13 +69,13 @@ func (t *Trie) maybeUpdateGeneration(n node) node { } // Make a copy if the generation is updated. - if n.getGeneration() < t.generation { + if n.GetGeneration() < t.generation { // Insert a new node in the current generation. - newNode := n.copy() - newNode.setGeneration(t.generation) + newNode := n.Copy() + newNode.SetGeneration(t.generation) // Hash of old nodes should already be computed since it belongs to older generation. - oldNodeHash := n.getHash() + oldNodeHash := n.GetHash() if len(oldNodeHash) > 0 { hash := common.BytesToHash(oldNodeHash) t.deletedKeys = append(t.deletedKeys, hash) @@ -261,12 +261,12 @@ func (t *Trie) insert(parent node, key []byte, value node) node { case *branch: n := t.updateBranch(p, key, value) - if p != nil && n != nil && n.isDirty() { - p.setDirty(true) + if p != nil && n != nil && n.IsDirty() { + p.SetDirty(true) } return n case nil: - value.setKey(key) + value.SetKey(key) return value case *leaf: // if a value already exists in the trie at this key, overwrite it with the new value @@ -288,19 +288,19 @@ func (t *Trie) insert(parent node, key []byte, value node) node { // value goes at this branch if len(key) == length { br.value = value.(*leaf).value - br.setDirty(true) + br.SetDirty(true) // if we are not replacing previous leaf, then add it as a child to the new branch if len(parentKey) > len(key) { p.key = p.key[length+1:] br.children[parentKey[length]] = p - p.setDirty(true) + p.SetDirty(true) } return br } - value.setKey(key[length+1:]) + value.SetKey(key[length+1:]) if length == len(p.key) { // if leaf's key is covered by this branch, then make the leaf's @@ -310,7 +310,7 @@ func (t *Trie) insert(parent node, key []byte, value node) node { } else { // otherwise, make the leaf a child of the branch and update its partial key p.key = p.key[length+1:] - p.setDirty(true) + p.SetDirty(true) br.children[parentKey[length]] = p br.children[key[length]] = value } @@ -331,7 +331,7 @@ func (t *Trie) updateBranch(p *branch, key []byte, value node) (n node) { if length == len(p.key) { // if node has same key as this branch, then update the value at this branch if bytes.Equal(key, p.key) { - p.setDirty(true) + p.SetDirty(true) switch v := value.(type) { case *branch: p.value = v.value @@ -345,14 +345,14 @@ func (t *Trie) updateBranch(p *branch, key []byte, value node) (n node) { case *branch, *leaf: n = t.insert(c, key[length+1:], value) p.children[key[length]] = n - n.setDirty(true) - p.setDirty(true) + n.SetDirty(true) + p.SetDirty(true) return p case nil: // otherwise, add node as child of this branch value.(*leaf).key = key[length+1:] p.children[key[length]] = value - p.setDirty(true) + p.SetDirty(true) return p } @@ -372,7 +372,7 @@ func (t *Trie) updateBranch(p *branch, key []byte, value node) (n node) { br.children[key[length]] = t.insert(nil, key[length+1:], value) } - br.setDirty(true) + br.SetDirty(true) return br } @@ -538,7 +538,7 @@ func (t *Trie) clearPrefixLimit(cn node, prefix []byte, limit *uint32) (node, bo i := prefix[len(c.key)] c.children[i], _ = t.deleteNodes(c.children[i], []byte{}, limit) - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) if c.children[i] == nil { @@ -557,11 +557,11 @@ func (t *Trie) clearPrefixLimit(cn node, prefix []byte, limit *uint32) (node, bo var wasUpdated, allDeleted bool c.children[i], wasUpdated, allDeleted = t.clearPrefixLimit(c.children[i], 
prefix[len(c.key)+1:], limit) if wasUpdated { - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) } - return curr, curr.isDirty(), allDeleted + return curr, curr.IsDirty(), allDeleted case *leaf: length := lenCommonPrefix(c.key, prefix) if length == len(prefix) { @@ -603,7 +603,7 @@ func (t *Trie) deleteNodes(cn node, prefix []byte, limit *uint32) (node, bool) { continue } - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) isAllNil := c.numChildren() == 0 if isAllNil && c.value == nil { @@ -661,7 +661,7 @@ func (t *Trie) clearPrefix(cn node, prefix []byte) (node, bool) { // found prefix at child index, delete child i := prefix[len(c.key)] c.children[i] = nil - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) return curr, true } @@ -676,11 +676,11 @@ func (t *Trie) clearPrefix(cn node, prefix []byte) (node, bool) { c.children[i], wasUpdated = t.clearPrefix(c.children[i], prefix[len(c.key)+1:]) if wasUpdated { - c.setDirty(true) + c.SetDirty(true) curr = handleDeletion(c, prefix) } - return curr, curr.isDirty() + return curr, curr.IsDirty() case *leaf: length := lenCommonPrefix(c.key, prefix) if length == len(prefix) { @@ -709,7 +709,7 @@ func (t *Trie) delete(parent node, key []byte) (node, bool) { if bytes.Equal(p.key, key) || len(key) == 0 { // found the value at this node p.value = nil - p.setDirty(true) + p.SetDirty(true) return handleDeletion(p, key), true } @@ -720,7 +720,7 @@ func (t *Trie) delete(parent node, key []byte) (node, bool) { } p.children[key[length]] = n - p.setDirty(true) + p.SetDirty(true) n = handleDeletion(p, key) return n, true case *leaf: @@ -779,7 +779,7 @@ func handleDeletion(p *branch, key []byte) node { default: // do nothing } - n.setDirty(true) + n.SetDirty(true) } return n diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index f6ae3ff779..3d3bf170d1 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -515,7 +515,7 @@ func TestTrieDiff(t *testing.T) { } dbTrie := NewEmptyTrie() - err = dbTrie.Load(storageDB, common.BytesToHash(newTrie.root.getHash())) + err = dbTrie.Load(storageDB, common.BytesToHash(newTrie.root.GetHash())) require.NoError(t, err) } From c94b882723ac077fb7d4d17895b0c0569305c4ac Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 23 Nov 2021 13:40:03 +0000 Subject: [PATCH 02/50] Export `node` interface --- lib/trie/database.go | 16 ++++----- lib/trie/hash.go | 14 ++++---- lib/trie/lookup.go | 2 +- lib/trie/node.go | 16 ++++----- lib/trie/node_test.go | 84 +++++++++++++++++++++---------------------- lib/trie/print.go | 2 +- lib/trie/trie.go | 34 +++++++++--------- 7 files changed, 84 insertions(+), 84 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 2c63d4b4bc..d35ed003d9 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -31,7 +31,7 @@ func (t *Trie) Store(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) store(db chaindb.Batch, curr node) error { +func (t *Trie) store(db chaindb.Batch, curr Node) error { if curr == nil { return nil } @@ -72,7 +72,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return ErrEmptyProof } - mappedNodes := make(map[string]node, len(proof)) + mappedNodes := make(map[string]Node, len(proof)) // map all the proofs hash -> decoded node // and takes the loop to indentify the root node @@ -103,7 +103,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // loadProof is a recursive function that will create all the trie paths based // on the 
mapped proofs slice starting by the root -func (t *Trie) loadProof(proof map[string]node, curr node) { +func (t *Trie) loadProof(proof map[string]Node, curr Node) { c, ok := curr.(*branch) if !ok { return @@ -148,7 +148,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return t.load(db, t.root) } -func (t *Trie) load(db chaindb.Database, curr node) error { +func (t *Trie) load(db chaindb.Database, curr Node) error { if c, ok := curr.(*branch); ok { for i, child := range c.children { if child == nil { @@ -181,7 +181,7 @@ func (t *Trie) load(db chaindb.Database, curr node) error { } // GetNodeHashes return hash of each key of the trie. -func (t *Trie) GetNodeHashes(curr node, keys map[common.Hash]struct{}) error { +func (t *Trie) GetNodeHashes(curr Node, keys map[common.Hash]struct{}) error { if c, ok := curr.(*branch); ok { for _, child := range c.children { if child == nil { @@ -249,7 +249,7 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return getFromDB(db, rootNode, k) } -func getFromDB(db chaindb.Database, parent node, key []byte) ([]byte, error) { +func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { var value []byte switch p := parent.(type) { @@ -308,7 +308,7 @@ func (t *Trie) WriteDirty(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) writeDirty(db chaindb.Batch, curr node) error { +func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error { if curr == nil || !curr.IsDirty() { return nil } @@ -356,7 +356,7 @@ func (t *Trie) GetInsertedNodeHashes() ([]common.Hash, error) { return t.getInsertedNodeHashes(t.root) } -func (t *Trie) getInsertedNodeHashes(curr node) ([]common.Hash, error) { +func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) { var nodeHashes []common.Hash if curr == nil || !curr.IsDirty() { return nil, nil diff --git a/lib/trie/hash.go b/lib/trie/hash.go index 87a18d439e..eaf7466c87 100644 --- a/lib/trie/hash.go +++ b/lib/trie/hash.go @@ -44,7 +44,7 @@ var hasherPool = &sync.Pool{ }, } -func hashNode(n node, digestBuffer io.Writer) (err error) { +func hashNode(n Node, digestBuffer io.Writer) (err error) { encodingBuffer := encodingBufferPool.Get().(*bytes.Buffer) encodingBuffer.Reset() defer encodingBufferPool.Put(encodingBuffer) @@ -96,7 +96,7 @@ type bytesBuffer interface { // It is the high-level function wrapping the encoding for different // node types. 
The encoding has the following format: // NodeHeader | Extra partial key length | Partial Key | Value -func encodeNode(n node, buffer bytesBuffer, parallel bool) (err error) { +func encodeNode(n Node, buffer bytesBuffer, parallel bool) (err error) { switch n := n.(type) { case *branch: err := encodeBranch(n, buffer, parallel) @@ -186,7 +186,7 @@ func encodeBranch(b *branch, buffer io.Writer, parallel bool) (err error) { return nil } -func encodeChildrenInParallel(children [16]node, buffer io.Writer) (err error) { +func encodeChildrenInParallel(children [16]Node, buffer io.Writer) (err error) { type result struct { index int buffer *bytes.Buffer @@ -196,7 +196,7 @@ func encodeChildrenInParallel(children [16]node, buffer io.Writer) (err error) { resultsCh := make(chan result) for i, child := range children { - go func(index int, child node) { + go func(index int, child Node) { buffer := encodingBufferPool.Get().(*bytes.Buffer) buffer.Reset() // buffer is put back in the pool after processing its @@ -253,7 +253,7 @@ func encodeChildrenInParallel(children [16]node, buffer io.Writer) (err error) { return err } -func encodeChildrenSequentially(children [16]node, buffer io.Writer) (err error) { +func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error) { for i, child := range children { err = encodeChild(child, buffer) if err != nil { @@ -263,7 +263,7 @@ func encodeChildrenSequentially(children [16]node, buffer io.Writer) (err error) return nil } -func encodeChild(child node, buffer io.Writer) (err error) { +func encodeChild(child Node, buffer io.Writer) (err error) { var isNil bool switch impl := child.(type) { case *branch: @@ -290,7 +290,7 @@ func encodeChild(child node, buffer io.Writer) (err error) { return nil } -func encodeAndHash(n node) (b []byte, err error) { +func encodeAndHash(n Node) (b []byte, err error) { buffer := digestBufferPool.Get().(*bytes.Buffer) buffer.Reset() defer digestBufferPool.Put(buffer) diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 948f270780..395604c04d 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -12,7 +12,7 @@ func findAndRecord(t *Trie, key []byte, recorder *recorder) error { return find(t.root, key, recorder) } -func find(parent node, key []byte, recorder *recorder) error { +func find(parent Node, key []byte, recorder *recorder) error { enc, hash, err := parent.EncodeAndHash() if err != nil { return err diff --git a/lib/trie/node.go b/lib/trie/node.go index 381460fa7d..90578473f8 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -38,8 +38,8 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) -// node is the interface for trie methods -type node interface { +// Node is the interface for trie methods +type Node interface { EncodeAndHash() ([]byte, []byte, error) Decode(r io.Reader, h byte) error IsDirty() bool @@ -50,13 +50,13 @@ type node interface { GetHash() []byte GetGeneration() uint64 SetGeneration(uint64) - Copy() node + Copy() Node } type ( branch struct { key []byte // partial key - children [16]node + children [16]Node value []byte dirty bool hash []byte @@ -84,7 +84,7 @@ func (l *leaf) SetGeneration(generation uint64) { l.generation = generation } -func (b *branch) Copy() node { +func (b *branch) Copy() Node { b.RLock() defer b.RUnlock() @@ -110,7 +110,7 @@ func (b *branch) Copy() node { return cpy } -func (l *leaf) Copy() node { +func (l *leaf) Copy() Node { l.RLock() defer l.RUnlock() @@ -306,13 +306,13 @@ func (l *leaf) EncodeAndHash() (encoding, hash []byte, err error) { return encoding, 
hash, nil } -func decodeBytes(in []byte) (node, error) { +func decodeBytes(in []byte) (Node, error) { buffer := bytes.NewBuffer(in) return decode(buffer) } // decode wraps the decoding of different node types back into a node -func decode(r io.Reader) (node, error) { +func decode(r io.Reader) (Node, error) { header, err := readByte(r) if err != nil { return nil, err diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index f91a0cb2e4..01199f5563 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -36,7 +36,7 @@ func generateRand(size int) [][]byte { } func TestChildrenBitmap(t *testing.T) { - b := &branch{children: [16]node{}} + b := &branch{children: [16]Node{}} res := b.childrenBitmap() if res != 0 { t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) @@ -66,24 +66,24 @@ func TestBranchHeader(t *testing.T) { br *branch header []byte }{ - {&branch{key: nil, children: [16]node{}, value: nil}, []byte{0x80}}, - {&branch{key: []byte{0x00}, children: [16]node{}, value: nil}, []byte{0x81}}, - {&branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]node{}, value: nil}, []byte{0x84}}, - - {&branch{key: nil, children: [16]node{}, value: []byte{0x01}}, []byte{0xc0}}, - {&branch{key: []byte{0x00}, children: [16]node{}, value: []byte{0x01}}, []byte{0xc1}}, - {&branch{key: []byte{0x00, 0x00}, children: [16]node{}, value: []byte{0x01}}, []byte{0xc2}}, - {&branch{key: []byte{0x00, 0x00, 0xf}, children: [16]node{}, value: []byte{0x01}}, []byte{0xc3}}, - - {&branch{key: byteArray(62), children: [16]node{}, value: nil}, []byte{0xbe}}, - {&branch{key: byteArray(62), children: [16]node{}, value: []byte{0x00}}, []byte{0xfe}}, - {&branch{key: byteArray(63), children: [16]node{}, value: nil}, []byte{0xbf, 0}}, - {&branch{key: byteArray(64), children: [16]node{}, value: nil}, []byte{0xbf, 1}}, - {&branch{key: byteArray(64), children: [16]node{}, value: []byte{0x01}}, []byte{0xff, 1}}, - - {&branch{key: byteArray(317), children: [16]node{}, value: []byte{0x01}}, []byte{255, 254}}, - {&branch{key: byteArray(318), children: [16]node{}, value: []byte{0x01}}, []byte{255, 255, 0}}, - {&branch{key: byteArray(573), children: [16]node{}, value: []byte{0x01}}, []byte{255, 255, 255, 0}}, + {&branch{key: nil, children: [16]Node{}, value: nil}, []byte{0x80}}, + {&branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, []byte{0x81}}, + {&branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, []byte{0x84}}, + + {&branch{key: nil, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc0}}, + {&branch{key: []byte{0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc1}}, + {&branch{key: []byte{0x00, 0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc2}}, + {&branch{key: []byte{0x00, 0x00, 0xf}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc3}}, + + {&branch{key: byteArray(62), children: [16]Node{}, value: nil}, []byte{0xbe}}, + {&branch{key: byteArray(62), children: [16]Node{}, value: []byte{0x00}}, []byte{0xfe}}, + {&branch{key: byteArray(63), children: [16]Node{}, value: nil}, []byte{0xbf, 0}}, + {&branch{key: byteArray(64), children: [16]Node{}, value: nil}, []byte{0xbf, 1}}, + {&branch{key: byteArray(64), children: [16]Node{}, value: []byte{0x01}}, []byte{0xff, 1}}, + + {&branch{key: byteArray(317), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 254}}, + {&branch{key: byteArray(318), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 255, 0}}, + {&branch{key: byteArray(573), children: [16]Node{}, value: []byte{0x01}}, 
[]byte{255, 255, 255, 0}}, } for _, test := range tests { @@ -102,7 +102,7 @@ func TestFailingPk(t *testing.T) { br *branch header []byte }{ - {&branch{key: byteArray(2 << 16), children: [16]node{}, value: []byte{0x01}}, []byte{255, 254}}, + {&branch{key: byteArray(2 << 16), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 254}}, } for _, test := range tests { @@ -147,7 +147,7 @@ func TestBranchEncode(t *testing.T) { randVals := generateRand(101) for i, testKey := range randKeys { - b := &branch{key: testKey, children: [16]node{}, value: randVals[i]} + b := &branch{key: testKey, children: [16]Node{}, value: randVals[i]} expected := bytes.NewBuffer(nil) header, err := b.header() @@ -235,27 +235,27 @@ func TestEncodeRoot(t *testing.T) { func TestBranchDecode(t *testing.T) { tests := []*branch{ - {key: []byte{}, children: [16]node{}, value: nil}, - {key: []byte{0x00}, children: [16]node{}, value: nil}, - {key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]node{}, value: nil}, - {key: []byte{}, children: [16]node{}, value: []byte{0x01}}, - {key: []byte{}, children: [16]node{&leaf{}}, value: []byte{0x01}}, - {key: []byte{}, children: [16]node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, + {key: []byte{}, children: [16]Node{}, value: nil}, + {key: []byte{0x00}, children: [16]Node{}, value: nil}, + {key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, + {key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, + {key: []byte{}, children: [16]Node{&leaf{}}, value: []byte{0x01}}, + {key: []byte{}, children: [16]Node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, { key: []byte{}, - children: [16]node{ + children: [16]Node{ &leaf{}, nil, &leaf{}, nil, nil, nil, nil, nil, nil, &leaf{}, nil, &leaf{}, }, value: []byte{0x01}, }, - {key: byteArray(62), children: [16]node{}, value: nil}, - {key: byteArray(63), children: [16]node{}, value: nil}, - {key: byteArray(64), children: [16]node{}, value: nil}, - {key: byteArray(317), children: [16]node{}, value: []byte{0x01}}, - {key: byteArray(318), children: [16]node{}, value: []byte{0x01}}, - {key: byteArray(573), children: [16]node{}, value: []byte{0x01}}, + {key: byteArray(62), children: [16]Node{}, value: nil}, + {key: byteArray(63), children: [16]Node{}, value: nil}, + {key: byteArray(64), children: [16]Node{}, value: nil}, + {key: byteArray(317), children: [16]Node{}, value: []byte{0x01}}, + {key: byteArray(318), children: [16]Node{}, value: []byte{0x01}}, + {key: byteArray(573), children: [16]Node{}, value: []byte{0x01}}, } buffer := bytes.NewBuffer(nil) @@ -304,16 +304,16 @@ func TestLeafDecode(t *testing.T) { } func TestDecode(t *testing.T) { - tests := []node{ - &branch{key: []byte{}, children: [16]node{}, value: nil}, - &branch{key: []byte{0x00}, children: [16]node{}, value: nil}, - &branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]node{}, value: nil}, - &branch{key: []byte{}, children: [16]node{}, value: []byte{0x01}}, - &branch{key: []byte{}, children: [16]node{&leaf{}}, value: []byte{0x01}}, - &branch{key: []byte{}, children: [16]node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, + tests := []Node{ + &branch{key: []byte{}, children: [16]Node{}, value: nil}, + &branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, + &branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, + &branch{key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, + &branch{key: []byte{}, children: [16]Node{&leaf{}}, value: []byte{0x01}}, + &branch{key: []byte{}, children: [16]Node{&leaf{}, nil, &leaf{}}, value: 
[]byte{0x01}}, &branch{ key: []byte{}, - children: [16]node{ + children: [16]Node{ &leaf{}, nil, &leaf{}, nil, nil, nil, nil, nil, nil, &leaf{}, nil, &leaf{}}, diff --git a/lib/trie/print.go b/lib/trie/print.go index 0604de165b..a0c8a557e3 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -23,7 +23,7 @@ func (t *Trie) String() string { return fmt.Sprintf("\n%s", tree.Print()) } -func (t *Trie) string(tree gotree.Tree, curr node, idx int) { +func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { switch c := curr.(type) { case *branch: buffer := encodingBufferPool.Get().(*bytes.Buffer) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 578a11e18a..66080fe784 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -18,7 +18,7 @@ var EmptyHash, _ = NewEmptyTrie().Hash() // Use NewTrie to create a trie that sits on top of a database. type Trie struct { generation uint64 - root node + root Node childTries map[common.Hash]*Trie // Used to store the child tries. deletedKeys []common.Hash parallel bool @@ -30,7 +30,7 @@ func NewEmptyTrie() *Trie { } // NewTrie creates a trie with an existing root node -func NewTrie(root node) *Trie { +func NewTrie(root Node) *Trie { return &Trie{ root: root, childTries: make(map[common.Hash]*Trie), @@ -63,7 +63,7 @@ func (t *Trie) Snapshot() *Trie { return newTrie } -func (t *Trie) maybeUpdateGeneration(n node) node { +func (t *Trie) maybeUpdateGeneration(n Node) Node { if n == nil { return nil } @@ -102,7 +102,7 @@ func (t *Trie) DeepCopy() (*Trie, error) { } // RootNode returns the root of the trie -func (t *Trie) RootNode() node { +func (t *Trie) RootNode() Node { return t.root } @@ -140,7 +140,7 @@ func (t *Trie) Entries() map[string][]byte { return t.entries(t.root, nil, make(map[string][]byte)) } -func (t *Trie) entries(current node, prefix []byte, kv map[string][]byte) map[string][]byte { +func (t *Trie) entries(current Node, prefix []byte, kv map[string][]byte) map[string][]byte { switch c := current.(type) { case *branch: if c.value != nil { @@ -169,7 +169,7 @@ func (t *Trie) NextKey(key []byte) []byte { return nibblesToKeyLE(next) } -func (t *Trie) nextKey(curr node, prefix, key []byte) []byte { +func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { switch c := curr.(type) { case *branch: fullKey := append(prefix, c.key...) 
@@ -256,7 +256,7 @@ func (t *Trie) tryPut(key, value []byte) { } // insert attempts to insert a key with value into the trie -func (t *Trie) insert(parent node, key []byte, value node) node { +func (t *Trie) insert(parent Node, key []byte, value Node) Node { switch p := t.maybeUpdateGeneration(parent).(type) { case *branch: n := t.updateBranch(p, key, value) @@ -324,7 +324,7 @@ func (t *Trie) insert(parent node, key []byte, value node) node { // updateBranch attempts to add the value node to a branch // inserts the value node as the branch's child at the index that's // the first nibble of the key -func (t *Trie) updateBranch(p *branch, key []byte, value node) (n node) { +func (t *Trie) updateBranch(p *branch, key []byte, value Node) (n Node) { length := lenCommonPrefix(key, p.key) // whole parent key matches @@ -406,7 +406,7 @@ func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { return t.getKeysWithPrefix(t.root, []byte{}, p, [][]byte{}) } -func (t *Trie) getKeysWithPrefix(parent node, prefix, key []byte, keys [][]byte) [][]byte { +func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) [][]byte { switch p := parent.(type) { case *branch: length := lenCommonPrefix(p.key, key) @@ -437,7 +437,7 @@ func (t *Trie) getKeysWithPrefix(parent node, prefix, key []byte, keys [][]byte) // addAllKeys appends all keys that are descendants of the parent node to a slice of keys // it uses the prefix to determine the entire key -func (t *Trie) addAllKeys(parent node, prefix []byte, keys [][]byte) [][]byte { +func (t *Trie) addAllKeys(parent Node, prefix []byte, keys [][]byte) [][]byte { switch p := parent.(type) { case *branch: if p.value != nil { @@ -471,7 +471,7 @@ func (t *Trie) tryGet(key []byte) *leaf { return t.retrieve(t.root, k) } -func (t *Trie) retrieve(parent node, key []byte) *leaf { +func (t *Trie) retrieve(parent Node, key []byte) *leaf { var ( value *leaf ) @@ -520,7 +520,7 @@ func (t *Trie) ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { // clearPrefixLimit deletes the keys having the prefix till limit reached and returns updated trie root node, // true if any node in the trie got updated, and next bool returns true if there is no keys left with prefix. -func (t *Trie) clearPrefixLimit(cn node, prefix []byte, limit *uint32) (node, bool, bool) { +func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bool, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { @@ -578,7 +578,7 @@ func (t *Trie) clearPrefixLimit(cn node, prefix []byte, limit *uint32) (node, bo return nil, false, true } -func (t *Trie) deleteNodes(cn node, prefix []byte, limit *uint32) (node, bool) { +func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { @@ -644,7 +644,7 @@ func (t *Trie) ClearPrefix(prefix []byte) { t.root, _ = t.clearPrefix(t.root, p) } -func (t *Trie) clearPrefix(cn node, prefix []byte) (node, bool) { +func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { case *branch: @@ -700,7 +700,7 @@ func (t *Trie) Delete(key []byte) { t.root, _ = t.delete(t.root, k) } -func (t *Trie) delete(parent node, key []byte) (node, bool) { +func (t *Trie) delete(parent Node, key []byte) (Node, bool) { // Store the current node and return it, if the trie is not updated. 
switch p := t.maybeUpdateGeneration(parent).(type) { case *branch: @@ -740,8 +740,8 @@ func (t *Trie) delete(parent node, key []byte) (node, bool) { // handleDeletion is called when a value is deleted from a branch // if the updated branch only has 1 child, it should be combined with that child // if the updated branch only has a value, it should be turned into a leaf -func handleDeletion(p *branch, key []byte) node { - var n node = p +func handleDeletion(p *branch, key []byte) Node { + var n Node = p length := lenCommonPrefix(p.key, key) bitmap := p.childrenBitmap() From a10fd38ec9e60145c94b5b65526be32a4a8c2bf1 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 23 Nov 2021 13:40:56 +0000 Subject: [PATCH 03/50] export `branch` --- lib/trie/database.go | 14 +-- lib/trie/hash.go | 6 +- lib/trie/hash_test.go | 84 +++++++++--------- lib/trie/lookup.go | 2 +- lib/trie/node.go | 34 ++++---- lib/trie/node_mock_test.go | 172 ++++++++++++++++++------------------- lib/trie/node_test.go | 72 ++++++++-------- lib/trie/print.go | 2 +- lib/trie/trie.go | 36 ++++---- 9 files changed, 211 insertions(+), 211 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index d35ed003d9..e7ee40b3c4 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -46,7 +46,7 @@ func (t *Trie) store(db chaindb.Batch, curr Node) error { return err } - if c, ok := curr.(*branch); ok { + if c, ok := curr.(*Branch); ok { for _, child := range c.children { if child == nil { continue @@ -104,7 +104,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // loadProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root func (t *Trie) loadProof(proof map[string]Node, curr Node) { - c, ok := curr.(*branch) + c, ok := curr.(*Branch) if !ok { return } @@ -149,7 +149,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { } func (t *Trie) load(db chaindb.Database, curr Node) error { - if c, ok := curr.(*branch); ok { + if c, ok := curr.(*Branch); ok { for i, child := range c.children { if child == nil { continue @@ -182,7 +182,7 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { // GetNodeHashes return hash of each key of the trie. 
func (t *Trie) GetNodeHashes(curr Node, keys map[common.Hash]struct{}) error { - if c, ok := curr.(*branch); ok { + if c, ok := curr.(*Branch); ok { for _, child := range c.children { if child == nil { continue @@ -253,7 +253,7 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { var value []byte switch p := parent.(type) { - case *branch: + case *Branch: length := lenCommonPrefix(p.key, key) // found the value at this node @@ -333,7 +333,7 @@ func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error { return err } - if c, ok := curr.(*branch); ok { + if c, ok := curr.(*Branch); ok { for _, child := range c.children { if child == nil { continue @@ -379,7 +379,7 @@ func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) { nodeHash := common.BytesToHash(hash) nodeHashes = append(nodeHashes, nodeHash) - if c, ok := curr.(*branch); ok { + if c, ok := curr.(*Branch); ok { for _, child := range c.children { if child == nil { continue diff --git a/lib/trie/hash.go b/lib/trie/hash.go index eaf7466c87..f668e18fc6 100644 --- a/lib/trie/hash.go +++ b/lib/trie/hash.go @@ -98,7 +98,7 @@ type bytesBuffer interface { // NodeHeader | Extra partial key length | Partial Key | Value func encodeNode(n Node, buffer bytesBuffer, parallel bool) (err error) { switch n := n.(type) { - case *branch: + case *Branch: err := encodeBranch(n, buffer, parallel) if err != nil { return fmt.Errorf("cannot encode branch: %w", err) @@ -131,7 +131,7 @@ func encodeNode(n Node, buffer bytesBuffer, parallel bool) (err error) { // encodeBranch encodes a branch with the encoding specified at the top of this package // to the buffer given. -func encodeBranch(b *branch, buffer io.Writer, parallel bool) (err error) { +func encodeBranch(b *Branch, buffer io.Writer, parallel bool) (err error) { if !b.dirty && b.encoding != nil { _, err = buffer.Write(b.encoding) if err != nil { @@ -266,7 +266,7 @@ func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error) func encodeChild(child Node, buffer io.Writer) (err error) { var isNil bool switch impl := child.(type) { - case *branch: + case *Branch: isNil = impl == nil case *leaf: isNil = impl == nil diff --git a/lib/trie/hash_test.go b/lib/trie/hash_test.go index 23265fdf57..8d6949fc00 100644 --- a/lib/trie/hash_test.go +++ b/lib/trie/hash_test.go @@ -27,14 +27,14 @@ func Test_hashNode(t *testing.T) { t.Parallel() testCases := map[string]struct { - n node + n Node writeCall bool write writeCall wrappedErr error errMessage string }{ "node encoding error": { - n: NewMocknode(nil), + n: NewMockNode(nil), wrappedErr: ErrNodeTypeUnsupported, errMessage: "cannot encode node: " + "node type is not supported: " + @@ -137,7 +137,7 @@ func Test_encodeNode(t *testing.T) { t.Parallel() testCases := map[string]struct { - n node + n Node writes []writeCall leafEncodingCopy bool leafBufferLen int @@ -147,7 +147,7 @@ func Test_encodeNode(t *testing.T) { errMessage string }{ "branch error": { - n: &branch{ + n: &Branch{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -159,7 +159,7 @@ func Test_encodeNode(t *testing.T) { "test error", }, "branch success": { - n: &branch{ + n: &Branch{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -202,7 +202,7 @@ func Test_encodeNode(t *testing.T) { }, }, "unsupported node type": { - n: NewMocknode(nil), + n: NewMockNode(nil), wrappedErr: ErrNodeTypeUnsupported, errMessage: "node type is not supported: *trie.Mocknode", }, @@ -252,14 +252,14 @@ func Test_encodeBranch(t *testing.T) { t.Parallel() 
testCases := map[string]struct { - branch *branch + branch *Branch writes []writeCall parallel bool wrappedErr error errMessage string }{ "clean branch with encoding": { - branch: &branch{ + branch: &Branch{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -269,7 +269,7 @@ func Test_encodeBranch(t *testing.T) { }, }, "write error for clean branch with encoding": { - branch: &branch{ + branch: &Branch{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -282,14 +282,14 @@ func Test_encodeBranch(t *testing.T) { errMessage: "cannot write stored encoding to buffer: test error", }, "header encoding error": { - branch: &branch{ + branch: &Branch{ key: make([]byte, 63+(1<<16)), }, wrappedErr: ErrPartialKeyTooBig, errMessage: "cannot encode header: partial key length greater than or equal to 2^16", }, "buffer write error for encoded header": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, }, @@ -303,7 +303,7 @@ func Test_encodeBranch(t *testing.T) { errMessage: "cannot write encoded header to buffer: test error", }, "buffer write error for encoded key": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, }, @@ -320,10 +320,10 @@ func Test_encodeBranch(t *testing.T) { errMessage: "cannot write encoded key to buffer: test error", }, "buffer write error for children bitmap": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, - children: [16]node{ + children: [16]Node{ nil, nil, nil, &leaf{key: []byte{9}}, nil, nil, nil, &leaf{key: []byte{11}}, }, @@ -344,10 +344,10 @@ func Test_encodeBranch(t *testing.T) { errMessage: "cannot write children bitmap to buffer: test error", }, "buffer write error for value": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, - children: [16]node{ + children: [16]Node{ nil, nil, nil, &leaf{key: []byte{9}}, nil, nil, nil, &leaf{key: []byte{11}}, }, @@ -371,10 +371,10 @@ func Test_encodeBranch(t *testing.T) { errMessage: "cannot write encoded value to buffer: test error", }, "buffer write error for children encoded sequentially": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, - children: [16]node{ + children: [16]Node{ nil, nil, nil, &leaf{key: []byte{9}}, nil, nil, nil, &leaf{key: []byte{11}}, }, @@ -403,10 +403,10 @@ func Test_encodeBranch(t *testing.T) { "failed to write child to buffer: test error", }, "buffer write error for children encoded in parallel": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, - children: [16]node{ + children: [16]Node{ nil, nil, nil, &leaf{key: []byte{9}}, nil, nil, nil, &leaf{key: []byte{11}}, }, @@ -439,10 +439,10 @@ func Test_encodeBranch(t *testing.T) { "test error", }, "success with parallel children encoding": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, - children: [16]node{ + children: [16]Node{ nil, nil, nil, &leaf{key: []byte{9}}, nil, nil, nil, &leaf{key: []byte{11}}, }, @@ -470,10 +470,10 @@ func Test_encodeBranch(t *testing.T) { parallel: true, }, "success with sequential children encoding": { - branch: &branch{ + branch: &Branch{ key: []byte{1, 2, 3}, value: []byte{100}, - children: [16]node{ + children: [16]Node{ nil, nil, nil, &leaf{key: []byte{9}}, nil, nil, nil, &leaf{key: []byte{11}}, }, @@ -538,14 +538,14 @@ func Test_encodeChildrenInParallel(t *testing.T) { t.Parallel() testCases := map[string]struct { - children [16]node + children [16]Node writes []writeCall wrappedErr error 
errMessage string }{ "no children": {}, "first child not nil": { - children: [16]node{ + children: [16]Node{ &leaf{key: []byte{1}}, }, writes: []writeCall{ @@ -555,7 +555,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { }, }, "last child not nil": { - children: [16]node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -568,7 +568,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { }, }, "first two children not nil": { - children: [16]node{ + children: [16]Node{ &leaf{key: []byte{1}}, &leaf{key: []byte{2}}, }, @@ -582,7 +582,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { }, }, "encoding error": { - children: [16]node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -638,14 +638,14 @@ func Test_encodeChildrenSequentially(t *testing.T) { t.Parallel() testCases := map[string]struct { - children [16]node + children [16]Node writes []writeCall wrappedErr error errMessage string }{ "no children": {}, "first child not nil": { - children: [16]node{ + children: [16]Node{ &leaf{key: []byte{1}}, }, writes: []writeCall{ @@ -655,7 +655,7 @@ func Test_encodeChildrenSequentially(t *testing.T) { }, }, "last child not nil": { - children: [16]node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -668,7 +668,7 @@ func Test_encodeChildrenSequentially(t *testing.T) { }, }, "first two children not nil": { - children: [16]node{ + children: [16]Node{ &leaf{key: []byte{1}}, &leaf{key: []byte{2}}, }, @@ -682,7 +682,7 @@ func Test_encodeChildrenSequentially(t *testing.T) { }, }, "encoding error": { - children: [16]node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, @@ -740,7 +740,7 @@ func Test_encodeChild(t *testing.T) { t.Parallel() testCases := map[string]struct { - child node + child Node writeCall bool write writeCall wrappedErr error @@ -751,7 +751,7 @@ func Test_encodeChild(t *testing.T) { child: (*leaf)(nil), }, "nil branch": { - child: (*branch)(nil), + child: (*Branch)(nil), }, "empty leaf child": { child: &leaf{}, @@ -761,14 +761,14 @@ func Test_encodeChild(t *testing.T) { }, }, "empty branch child": { - child: &branch{}, + child: &Branch{}, writeCall: true, write: writeCall{ written: []byte{12, 128, 0, 0}, }, }, "buffer write error": { - child: &branch{}, + child: &Branch{}, writeCall: true, write: writeCall{ written: []byte{12, 128, 0, 0}, @@ -788,10 +788,10 @@ func Test_encodeChild(t *testing.T) { }, }, "branch child": { - child: &branch{ + child: &Branch{ key: []byte{1}, value: []byte{2}, - children: [16]node{ + children: [16]Node{ nil, nil, &leaf{ key: []byte{5}, value: []byte{6}, @@ -835,13 +835,13 @@ func Test_encodeAndHash(t *testing.T) { t.Parallel() testCases := map[string]struct { - n node + n Node b []byte wrappedErr error errMessage string }{ "node encoding error": { - n: NewMocknode(nil), + n: NewMockNode(nil), wrappedErr: ErrNodeTypeUnsupported, errMessage: "cannot hash node: " + "cannot encode node: " + diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 395604c04d..0e38e19ea8 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -20,7 +20,7 @@ func find(parent Node, key []byte, recorder *recorder) error { recorder.record(hash, enc) - b, ok := parent.(*branch) + b, ok := parent.(*Branch) if !ok { return nil } diff --git a/lib/trie/node.go b/lib/trie/node.go index 90578473f8..17ff14f727 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -54,7 +54,7 @@ type Node interface { } type ( - branch 
struct { + Branch struct { key []byte // partial key children [16]Node value []byte @@ -76,7 +76,7 @@ type ( } ) -func (b *branch) SetGeneration(generation uint64) { +func (b *Branch) SetGeneration(generation uint64) { b.generation = generation } @@ -84,11 +84,11 @@ func (l *leaf) SetGeneration(generation uint64) { l.generation = generation } -func (b *branch) Copy() Node { +func (b *Branch) Copy() Node { b.RLock() defer b.RUnlock() - cpy := &branch{ + cpy := &Branch{ key: make([]byte, len(b.key)), children: b.children, // copy interface pointers value: nil, @@ -132,7 +132,7 @@ func (l *leaf) Copy() Node { return cpy } -func (b *branch) SetEncodingAndHash(enc, hash []byte) { +func (b *Branch) SetEncodingAndHash(enc, hash []byte) { b.encoding = enc b.hash = hash } @@ -145,11 +145,11 @@ func (l *leaf) SetEncodingAndHash(enc, hash []byte) { l.hash = hash } -func (b *branch) GetHash() []byte { +func (b *Branch) GetHash() []byte { return b.hash } -func (b *branch) GetGeneration() uint64 { +func (b *Branch) GetGeneration() uint64 { return b.generation } @@ -161,7 +161,7 @@ func (l *leaf) GetHash() []byte { return l.hash } -func (b *branch) String() string { +func (b *Branch) String() string { if len(b.value) > 1024 { return fmt.Sprintf( "branch key=%x childrenBitmap=%16b value (hashed)=%x dirty=%v", @@ -177,7 +177,7 @@ func (l *leaf) String() string { return fmt.Sprintf("leaf key=%x value=%v dirty=%v", l.key, l.value, l.dirty) } -func (b *branch) childrenBitmap() uint16 { +func (b *Branch) childrenBitmap() uint16 { var bitmap uint16 var i uint for i = 0; i < 16; i++ { @@ -188,7 +188,7 @@ func (b *branch) childrenBitmap() uint16 { return bitmap } -func (b *branch) numChildren() int { +func (b *Branch) numChildren() int { var i, count int for i = 0; i < 16; i++ { if b.children[i] != nil { @@ -202,7 +202,7 @@ func (l *leaf) IsDirty() bool { return l.dirty } -func (b *branch) IsDirty() bool { +func (b *Branch) IsDirty() bool { return b.dirty } @@ -210,7 +210,7 @@ func (l *leaf) SetDirty(dirty bool) { l.dirty = dirty } -func (b *branch) SetDirty(dirty bool) { +func (b *Branch) SetDirty(dirty bool) { b.dirty = dirty } @@ -218,11 +218,11 @@ func (l *leaf) SetKey(key []byte) { l.key = key } -func (b *branch) SetKey(key []byte) { +func (b *Branch) SetKey(key []byte) { b.key = key } -func (b *branch) EncodeAndHash() (encoding, hash []byte, err error) { +func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { if !b.dirty && b.encoding != nil && b.hash != nil { return b.encoding, b.hash, nil } @@ -324,7 +324,7 @@ func decode(r io.Reader) (Node, error) { err := l.Decode(r, header) return l, err } else if nodeType == 2 || nodeType == 3 { - b := new(branch) + b := new(Branch) err := b.Decode(r, header) return b, err } @@ -335,7 +335,7 @@ func decode(r io.Reader) (Node, error) { // Decode decodes a byte array with the encoding specified at the top of this package into a branch node // Note that since the encoded branch stores the hash of the children nodes, we aren't able to reconstruct the child // nodes from the encoding. This function instead stubs where the children are known to be with an empty leaf. 
-func (b *branch) Decode(r io.Reader, header byte) (err error) { +func (b *Branch) Decode(r io.Reader, header byte) (err error) { if header == 0 { header, err = readByte(r) if err != nil { @@ -427,7 +427,7 @@ func (l *leaf) Decode(r io.Reader, header byte) (err error) { return nil } -func (b *branch) header() ([]byte, error) { +func (b *Branch) header() ([]byte, error) { var header byte if b.value == nil { header = 2 << 6 diff --git a/lib/trie/node_mock_test.go b/lib/trie/node_mock_test.go index 0235377728..d381d1a157 100644 --- a/lib/trie/node_mock_test.go +++ b/lib/trie/node_mock_test.go @@ -11,173 +11,173 @@ import ( gomock "github.com/golang/mock/gomock" ) -// Mocknode is a mock of node interface. -type Mocknode struct { +// MockNode is a mock of Node interface. +type MockNode struct { ctrl *gomock.Controller - recorder *MocknodeMockRecorder + recorder *MockNodeMockRecorder } -// MocknodeMockRecorder is the mock recorder for Mocknode. -type MocknodeMockRecorder struct { - mock *Mocknode +// MockNodeMockRecorder is the mock recorder for MockNode. +type MockNodeMockRecorder struct { + mock *MockNode } -// NewMocknode creates a new mock instance. -func NewMocknode(ctrl *gomock.Controller) *Mocknode { - mock := &Mocknode{ctrl: ctrl} - mock.recorder = &MocknodeMockRecorder{mock} +// NewMockNode creates a new mock instance. +func NewMockNode(ctrl *gomock.Controller) *MockNode { + mock := &MockNode{ctrl: ctrl} + mock.recorder = &MockNodeMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *Mocknode) EXPECT() *MocknodeMockRecorder { +func (m *MockNode) EXPECT() *MockNodeMockRecorder { return m.recorder } -// String mocks base method. -func (m *Mocknode) String() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "String") - ret0, _ := ret[0].(string) - return ret0 -} - -// String indicates an expected call of String. -func (mr *MocknodeMockRecorder) String() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*Mocknode)(nil).String)) -} - -// copy mocks base method. -func (m *Mocknode) copy() node { +// Copy mocks base method. +func (m *MockNode) Copy() Node { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "copy") - ret0, _ := ret[0].(node) + ret := m.ctrl.Call(m, "Copy") + ret0, _ := ret[0].(Node) return ret0 } -// copy indicates an expected call of copy. -func (mr *MocknodeMockRecorder) copy() *gomock.Call { +// Copy indicates an expected call of Copy. +func (mr *MockNodeMockRecorder) Copy() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "copy", reflect.TypeOf((*Mocknode)(nil).copy)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Copy", reflect.TypeOf((*MockNode)(nil).Copy)) } -// decode mocks base method. -func (m *Mocknode) decode(r io.Reader, h byte) error { +// Decode mocks base method. +func (m *MockNode) Decode(r io.Reader, h byte) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "decode", r, h) + ret := m.ctrl.Call(m, "Decode", r, h) ret0, _ := ret[0].(error) return ret0 } -// decode indicates an expected call of decode. -func (mr *MocknodeMockRecorder) decode(r, h interface{}) *gomock.Call { +// Decode indicates an expected call of Decode. 
+func (mr *MockNodeMockRecorder) Decode(r, h interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "decode", reflect.TypeOf((*Mocknode)(nil).decode), r, h) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decode", reflect.TypeOf((*MockNode)(nil).Decode), r, h) } -// encodeAndHash mocks base method. -func (m *Mocknode) encodeAndHash() ([]byte, []byte, error) { +// EncodeAndHash mocks base method. +func (m *MockNode) EncodeAndHash() ([]byte, []byte, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "encodeAndHash") + ret := m.ctrl.Call(m, "EncodeAndHash") ret0, _ := ret[0].([]byte) ret1, _ := ret[1].([]byte) ret2, _ := ret[2].(error) return ret0, ret1, ret2 } -// encodeAndHash indicates an expected call of encodeAndHash. -func (mr *MocknodeMockRecorder) encodeAndHash() *gomock.Call { +// EncodeAndHash indicates an expected call of EncodeAndHash. +func (mr *MockNodeMockRecorder) EncodeAndHash() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "encodeAndHash", reflect.TypeOf((*Mocknode)(nil).encodeAndHash)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncodeAndHash", reflect.TypeOf((*MockNode)(nil).EncodeAndHash)) } -// getGeneration mocks base method. -func (m *Mocknode) getGeneration() uint64 { +// GetGeneration mocks base method. +func (m *MockNode) GetGeneration() uint64 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getGeneration") + ret := m.ctrl.Call(m, "GetGeneration") ret0, _ := ret[0].(uint64) return ret0 } -// getGeneration indicates an expected call of getGeneration. -func (mr *MocknodeMockRecorder) getGeneration() *gomock.Call { +// GetGeneration indicates an expected call of GetGeneration. +func (mr *MockNodeMockRecorder) GetGeneration() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getGeneration", reflect.TypeOf((*Mocknode)(nil).getGeneration)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGeneration", reflect.TypeOf((*MockNode)(nil).GetGeneration)) } -// getHash mocks base method. -func (m *Mocknode) getHash() []byte { +// GetHash mocks base method. +func (m *MockNode) GetHash() []byte { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "getHash") + ret := m.ctrl.Call(m, "GetHash") ret0, _ := ret[0].([]byte) return ret0 } -// getHash indicates an expected call of getHash. -func (mr *MocknodeMockRecorder) getHash() *gomock.Call { +// GetHash indicates an expected call of GetHash. +func (mr *MockNodeMockRecorder) GetHash() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getHash", reflect.TypeOf((*Mocknode)(nil).getHash)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHash", reflect.TypeOf((*MockNode)(nil).GetHash)) } -// isDirty mocks base method. -func (m *Mocknode) isDirty() bool { +// IsDirty mocks base method. +func (m *MockNode) IsDirty() bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "isDirty") + ret := m.ctrl.Call(m, "IsDirty") ret0, _ := ret[0].(bool) return ret0 } -// isDirty indicates an expected call of isDirty. -func (mr *MocknodeMockRecorder) isDirty() *gomock.Call { +// IsDirty indicates an expected call of IsDirty. 
+func (mr *MockNodeMockRecorder) IsDirty() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isDirty", reflect.TypeOf((*Mocknode)(nil).isDirty)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDirty", reflect.TypeOf((*MockNode)(nil).IsDirty)) } -// setDirty mocks base method. -func (m *Mocknode) setDirty(dirty bool) { +// SetDirty mocks base method. +func (m *MockNode) SetDirty(dirty bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "setDirty", dirty) + m.ctrl.Call(m, "SetDirty", dirty) } -// setDirty indicates an expected call of setDirty. -func (mr *MocknodeMockRecorder) setDirty(dirty interface{}) *gomock.Call { +// SetDirty indicates an expected call of SetDirty. +func (mr *MockNodeMockRecorder) SetDirty(dirty interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setDirty", reflect.TypeOf((*Mocknode)(nil).setDirty), dirty) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDirty", reflect.TypeOf((*MockNode)(nil).SetDirty), dirty) } -// setEncodingAndHash mocks base method. -func (m *Mocknode) setEncodingAndHash(arg0, arg1 []byte) { +// SetEncodingAndHash mocks base method. +func (m *MockNode) SetEncodingAndHash(arg0, arg1 []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "setEncodingAndHash", arg0, arg1) + m.ctrl.Call(m, "SetEncodingAndHash", arg0, arg1) } -// setEncodingAndHash indicates an expected call of setEncodingAndHash. -func (mr *MocknodeMockRecorder) setEncodingAndHash(arg0, arg1 interface{}) *gomock.Call { +// SetEncodingAndHash indicates an expected call of SetEncodingAndHash. +func (mr *MockNodeMockRecorder) SetEncodingAndHash(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setEncodingAndHash", reflect.TypeOf((*Mocknode)(nil).setEncodingAndHash), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetEncodingAndHash", reflect.TypeOf((*MockNode)(nil).SetEncodingAndHash), arg0, arg1) } -// setGeneration mocks base method. -func (m *Mocknode) setGeneration(arg0 uint64) { +// SetGeneration mocks base method. +func (m *MockNode) SetGeneration(arg0 uint64) { m.ctrl.T.Helper() - m.ctrl.Call(m, "setGeneration", arg0) + m.ctrl.Call(m, "SetGeneration", arg0) } -// setGeneration indicates an expected call of setGeneration. -func (mr *MocknodeMockRecorder) setGeneration(arg0 interface{}) *gomock.Call { +// SetGeneration indicates an expected call of SetGeneration. +func (mr *MockNodeMockRecorder) SetGeneration(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setGeneration", reflect.TypeOf((*Mocknode)(nil).setGeneration), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGeneration", reflect.TypeOf((*MockNode)(nil).SetGeneration), arg0) } -// setKey mocks base method. -func (m *Mocknode) setKey(key []byte) { +// SetKey mocks base method. +func (m *MockNode) SetKey(key []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "setKey", key) + m.ctrl.Call(m, "SetKey", key) } -// setKey indicates an expected call of setKey. -func (mr *MocknodeMockRecorder) setKey(key interface{}) *gomock.Call { +// SetKey indicates an expected call of SetKey. +func (mr *MockNodeMockRecorder) SetKey(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetKey", reflect.TypeOf((*MockNode)(nil).SetKey), key) +} + +// String mocks base method. 
+func (m *MockNode) String() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "String") + ret0, _ := ret[0].(string) + return ret0 +} + +// String indicates an expected call of String. +func (mr *MockNodeMockRecorder) String() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "setKey", reflect.TypeOf((*Mocknode)(nil).setKey), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockNode)(nil).String)) } diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index 01199f5563..be380e10be 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -36,7 +36,7 @@ func generateRand(size int) [][]byte { } func TestChildrenBitmap(t *testing.T) { - b := &branch{children: [16]Node{}} + b := &Branch{children: [16]Node{}} res := b.childrenBitmap() if res != 0 { t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) @@ -63,27 +63,27 @@ func TestChildrenBitmap(t *testing.T) { func TestBranchHeader(t *testing.T) { tests := []struct { - br *branch + br *Branch header []byte }{ - {&branch{key: nil, children: [16]Node{}, value: nil}, []byte{0x80}}, - {&branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, []byte{0x81}}, - {&branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, []byte{0x84}}, - - {&branch{key: nil, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc0}}, - {&branch{key: []byte{0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc1}}, - {&branch{key: []byte{0x00, 0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc2}}, - {&branch{key: []byte{0x00, 0x00, 0xf}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc3}}, - - {&branch{key: byteArray(62), children: [16]Node{}, value: nil}, []byte{0xbe}}, - {&branch{key: byteArray(62), children: [16]Node{}, value: []byte{0x00}}, []byte{0xfe}}, - {&branch{key: byteArray(63), children: [16]Node{}, value: nil}, []byte{0xbf, 0}}, - {&branch{key: byteArray(64), children: [16]Node{}, value: nil}, []byte{0xbf, 1}}, - {&branch{key: byteArray(64), children: [16]Node{}, value: []byte{0x01}}, []byte{0xff, 1}}, - - {&branch{key: byteArray(317), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 254}}, - {&branch{key: byteArray(318), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 255, 0}}, - {&branch{key: byteArray(573), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 255, 255, 0}}, + {&Branch{key: nil, children: [16]Node{}, value: nil}, []byte{0x80}}, + {&Branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, []byte{0x81}}, + {&Branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, []byte{0x84}}, + + {&Branch{key: nil, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc0}}, + {&Branch{key: []byte{0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc1}}, + {&Branch{key: []byte{0x00, 0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc2}}, + {&Branch{key: []byte{0x00, 0x00, 0xf}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc3}}, + + {&Branch{key: byteArray(62), children: [16]Node{}, value: nil}, []byte{0xbe}}, + {&Branch{key: byteArray(62), children: [16]Node{}, value: []byte{0x00}}, []byte{0xfe}}, + {&Branch{key: byteArray(63), children: [16]Node{}, value: nil}, []byte{0xbf, 0}}, + {&Branch{key: byteArray(64), children: [16]Node{}, value: nil}, []byte{0xbf, 1}}, + {&Branch{key: byteArray(64), children: [16]Node{}, value: []byte{0x01}}, []byte{0xff, 1}}, + + {&Branch{key: byteArray(317), children: [16]Node{}, value: []byte{0x01}}, 
[]byte{255, 254}}, + {&Branch{key: byteArray(318), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 255, 0}}, + {&Branch{key: byteArray(573), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 255, 255, 0}}, } for _, test := range tests { @@ -99,10 +99,10 @@ func TestBranchHeader(t *testing.T) { func TestFailingPk(t *testing.T) { tests := []struct { - br *branch + br *Branch header []byte }{ - {&branch{key: byteArray(2 << 16), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 254}}, + {&Branch{key: byteArray(2 << 16), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 254}}, } for _, test := range tests { @@ -147,7 +147,7 @@ func TestBranchEncode(t *testing.T) { randVals := generateRand(101) for i, testKey := range randKeys { - b := &branch{key: testKey, children: [16]Node{}, value: randVals[i]} + b := &Branch{key: testKey, children: [16]Node{}, value: randVals[i]} expected := bytes.NewBuffer(nil) header, err := b.header() @@ -234,7 +234,7 @@ func TestEncodeRoot(t *testing.T) { } func TestBranchDecode(t *testing.T) { - tests := []*branch{ + tests := []*Branch{ {key: []byte{}, children: [16]Node{}, value: nil}, {key: []byte{0x00}, children: [16]Node{}, value: nil}, {key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, @@ -265,7 +265,7 @@ func TestBranchDecode(t *testing.T) { err := encodeBranch(test, buffer, parallel) require.NoError(t, err) - res := new(branch) + res := new(Branch) err = res.Decode(buffer, 0) require.NoError(t, err) @@ -305,13 +305,13 @@ func TestLeafDecode(t *testing.T) { func TestDecode(t *testing.T) { tests := []Node{ - &branch{key: []byte{}, children: [16]Node{}, value: nil}, - &branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, - &branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, - &branch{key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, - &branch{key: []byte{}, children: [16]Node{&leaf{}}, value: []byte{0x01}}, - &branch{key: []byte{}, children: [16]Node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, - &branch{ + &Branch{key: []byte{}, children: [16]Node{}, value: nil}, + &Branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, + &Branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, + &Branch{key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, + &Branch{key: []byte{}, children: [16]Node{&leaf{}}, value: []byte{0x01}}, + &Branch{key: []byte{}, children: [16]Node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, + &Branch{ key: []byte{}, children: [16]Node{ &leaf{}, nil, &leaf{}, nil, @@ -340,10 +340,10 @@ func TestDecode(t *testing.T) { require.NoError(t, err) switch n := test.(type) { - case *branch: - require.Equal(t, n.key, res.(*branch).key) - require.Equal(t, n.childrenBitmap(), res.(*branch).childrenBitmap()) - require.Equal(t, n.value, res.(*branch).value) + case *Branch: + require.Equal(t, n.key, res.(*Branch).key) + require.Equal(t, n.childrenBitmap(), res.(*Branch).childrenBitmap()) + require.Equal(t, n.value, res.(*Branch).value) case *leaf: require.Equal(t, n.key, res.(*leaf).key) require.Equal(t, n.value, res.(*leaf).value) diff --git a/lib/trie/print.go b/lib/trie/print.go index a0c8a557e3..c8824a5ade 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -25,7 +25,7 @@ func (t *Trie) String() string { func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { switch c := curr.(type) { - case *branch: + case *Branch: buffer := encodingBufferPool.Get().(*bytes.Buffer) buffer.Reset() diff --git a/lib/trie/trie.go b/lib/trie/trie.go 
index 66080fe784..2881f3a781 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -142,7 +142,7 @@ func (t *Trie) Entries() map[string][]byte { func (t *Trie) entries(current Node, prefix []byte, kv map[string][]byte) map[string][]byte { switch c := current.(type) { - case *branch: + case *Branch: if c.value != nil { kv[string(nibblesToKeyLE(append(prefix, c.key...)))] = c.value } @@ -171,7 +171,7 @@ func (t *Trie) NextKey(key []byte) []byte { func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { switch c := curr.(type) { - case *branch: + case *Branch: fullKey := append(prefix, c.key...) var cmp int if len(key) < len(fullKey) { @@ -258,7 +258,7 @@ func (t *Trie) tryPut(key, value []byte) { // insert attempts to insert a key with value into the trie func (t *Trie) insert(parent Node, key []byte, value Node) Node { switch p := t.maybeUpdateGeneration(parent).(type) { - case *branch: + case *Branch: n := t.updateBranch(p, key, value) if p != nil && n != nil && n.IsDirty() { @@ -282,7 +282,7 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { length := lenCommonPrefix(key, p.key) // need to convert this leaf into a branch - br := &branch{key: key[:length], dirty: true, generation: t.generation} + br := &Branch{key: key[:length], dirty: true, generation: t.generation} parentKey := p.key // value goes at this branch @@ -324,7 +324,7 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { // updateBranch attempts to add the value node to a branch // inserts the value node as the branch's child at the index that's // the first nibble of the key -func (t *Trie) updateBranch(p *branch, key []byte, value Node) (n Node) { +func (t *Trie) updateBranch(p *Branch, key []byte, value Node) (n Node) { length := lenCommonPrefix(key, p.key) // whole parent key matches @@ -333,7 +333,7 @@ func (t *Trie) updateBranch(p *branch, key []byte, value Node) (n Node) { if bytes.Equal(key, p.key) { p.SetDirty(true) switch v := value.(type) { - case *branch: + case *Branch: p.value = v.value case *leaf: p.value = v.value @@ -342,7 +342,7 @@ func (t *Trie) updateBranch(p *branch, key []byte, value Node) (n Node) { } switch c := p.children[key[length]].(type) { - case *branch, *leaf: + case *Branch, *leaf: n = t.insert(c, key[length+1:], value) p.children[key[length]] = n n.SetDirty(true) @@ -361,7 +361,7 @@ func (t *Trie) updateBranch(p *branch, key []byte, value Node) (n Node) { // we need to branch out at the point where the keys diverge // update partial keys, new branch has key up to matching length - br := &branch{key: key[:length], dirty: true, generation: t.generation} + br := &Branch{key: key[:length], dirty: true, generation: t.generation} parentIndex := p.key[length] br.children[parentIndex] = t.insert(nil, p.key[length+1:], p) @@ -408,7 +408,7 @@ func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *branch: + case *Branch: length := lenCommonPrefix(p.key, key) if bytes.Equal(p.key[:length], key) || len(key) == 0 { @@ -439,7 +439,7 @@ func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) // it uses the prefix to determine the entire key func (t *Trie) addAllKeys(parent Node, prefix []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *branch: + case *Branch: if p.value != nil { keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) } @@ -477,7 +477,7 @@ func (t *Trie) retrieve(parent Node, 
key []byte) *leaf { ) switch p := parent.(type) { - case *branch: + case *Branch: length := lenCommonPrefix(p.key, key) // found the value at this node @@ -524,7 +524,7 @@ func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bo curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *branch: + case *Branch: length := lenCommonPrefix(c.key, prefix) if length == len(prefix) { n, _ := t.deleteNodes(c, []byte{}, limit) @@ -588,7 +588,7 @@ func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (Node, bool) { } *limit-- return nil, true - case *branch: + case *Branch: if len(c.key) != 0 { prefix = append(prefix, c.key...) } @@ -647,7 +647,7 @@ func (t *Trie) ClearPrefix(prefix []byte) { func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *branch: + case *Branch: length := lenCommonPrefix(c.key, prefix) if length == len(prefix) { @@ -703,7 +703,7 @@ func (t *Trie) Delete(key []byte) { func (t *Trie) delete(parent Node, key []byte) (Node, bool) { // Store the current node and return it, if the trie is not updated. switch p := t.maybeUpdateGeneration(parent).(type) { - case *branch: + case *Branch: length := lenCommonPrefix(p.key, key) if bytes.Equal(p.key, key) || len(key) == 0 { @@ -740,7 +740,7 @@ func (t *Trie) delete(parent Node, key []byte) (Node, bool) { // handleDeletion is called when a value is deleted from a branch // if the updated branch only has 1 child, it should be combined with that child // if the updated branch only has a value, it should be turned into a leaf -func handleDeletion(p *branch, key []byte) Node { +func handleDeletion(p *Branch, key []byte) Node { var n Node = p length := lenCommonPrefix(p.key, key) bitmap := p.childrenBitmap() @@ -763,8 +763,8 @@ func handleDeletion(p *branch, key []byte) Node { switch c := child.(type) { case *leaf: n = &leaf{key: append(append(p.key, []byte{byte(i)}...), c.key...), value: c.value} - case *branch: - br := new(branch) + case *Branch: + br := new(Branch) br.key = append(p.key, append([]byte{byte(i)}, c.key...)...) 
// adopt the grandchildren From b6fc29e26a44f1829b1ee482e303468b7219ba9d Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 23 Nov 2021 13:41:24 +0000 Subject: [PATCH 04/50] export `leaf` struct --- lib/trie/database.go | 6 +-- lib/trie/hash.go | 6 +-- lib/trie/hash_test.go | 88 +++++++++++++++++++++---------------------- lib/trie/node.go | 32 ++++++++-------- lib/trie/node_test.go | 70 +++++++++++++++++----------------- lib/trie/print.go | 2 +- lib/trie/trie.go | 50 ++++++++++++------------ lib/trie/trie_test.go | 4 +- 8 files changed, 129 insertions(+), 129 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index e7ee40b3c4..f53c0c139c 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -158,7 +158,7 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { hash := child.GetHash() enc, err := db.Get(hash) if err != nil { - return fmt.Errorf("failed to find node key=%x index=%d: %w", child.(*leaf).hash, i, err) + return fmt.Errorf("failed to find node key=%x index=%d: %w", child.(*Leaf).hash, i, err) } child, err = decodeBytes(enc) @@ -271,7 +271,7 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { } // load child with potential value - enc, err := db.Get(p.children[key[length]].(*leaf).hash) + enc, err := db.Get(p.children[key[length]].(*Leaf).hash) if err != nil { return nil, fmt.Errorf("failed to find node in database: %w", err) } @@ -285,7 +285,7 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { if err != nil { return nil, err } - case *leaf: + case *Leaf: if bytes.Equal(p.key, key) { return p.value, nil } diff --git a/lib/trie/hash.go b/lib/trie/hash.go index f668e18fc6..921436c5f6 100644 --- a/lib/trie/hash.go +++ b/lib/trie/hash.go @@ -104,7 +104,7 @@ func encodeNode(n Node, buffer bytesBuffer, parallel bool) (err error) { return fmt.Errorf("cannot encode branch: %w", err) } return nil - case *leaf: + case *Leaf: err := encodeLeaf(n, buffer) if err != nil { return fmt.Errorf("cannot encode leaf: %w", err) @@ -268,7 +268,7 @@ func encodeChild(child Node, buffer io.Writer) (err error) { switch impl := child.(type) { case *Branch: isNil = impl == nil - case *leaf: + case *Leaf: isNil = impl == nil default: isNil = child == nil @@ -309,7 +309,7 @@ func encodeAndHash(n Node) (b []byte, err error) { // encodeLeaf encodes a leaf to the buffer given, with the encoding // specified at the top of this package. 
-func encodeLeaf(l *leaf, buffer io.Writer) (err error) { +func encodeLeaf(l *Leaf, buffer io.Writer) (err error) { l.encodingMu.RLock() defer l.encodingMu.RUnlock() if !l.dirty && l.encoding != nil { diff --git a/lib/trie/hash_test.go b/lib/trie/hash_test.go index 8d6949fc00..2e2ea04317 100644 --- a/lib/trie/hash_test.go +++ b/lib/trie/hash_test.go @@ -38,10 +38,10 @@ func Test_hashNode(t *testing.T) { wrappedErr: ErrNodeTypeUnsupported, errMessage: "cannot encode node: " + "node type is not supported: " + - "*trie.Mocknode", + "*trie.MockNode", }, "small leaf buffer write error": { - n: &leaf{ + n: &Leaf{ encoding: []byte{1, 2, 3}, }, writeCall: true, @@ -54,7 +54,7 @@ func Test_hashNode(t *testing.T) { "test error", }, "small leaf success": { - n: &leaf{ + n: &Leaf{ encoding: []byte{1, 2, 3}, }, writeCall: true, @@ -63,7 +63,7 @@ func Test_hashNode(t *testing.T) { }, }, "leaf hash sum buffer write error": { - n: &leaf{ + n: &Leaf{ encoding: []byte{ 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, @@ -87,7 +87,7 @@ func Test_hashNode(t *testing.T) { "test error", }, "leaf hash sum success": { - n: &leaf{ + n: &Leaf{ encoding: []byte{ 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, @@ -167,7 +167,7 @@ func Test_encodeNode(t *testing.T) { }, }, "leaf error": { - n: &leaf{ + n: &Leaf{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -179,7 +179,7 @@ func Test_encodeNode(t *testing.T) { "test error", }, "leaf success": { - n: &leaf{ + n: &Leaf{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -204,7 +204,7 @@ func Test_encodeNode(t *testing.T) { "unsupported node type": { n: NewMockNode(nil), wrappedErr: ErrNodeTypeUnsupported, - errMessage: "node type is not supported: *trie.Mocknode", + errMessage: "node type is not supported: *trie.MockNode", }, } @@ -324,8 +324,8 @@ func Test_encodeBranch(t *testing.T) { key: []byte{1, 2, 3}, value: []byte{100}, children: [16]Node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, + nil, nil, nil, &Leaf{key: []byte{9}}, + nil, nil, nil, &Leaf{key: []byte{11}}, }, }, writes: []writeCall{ @@ -348,8 +348,8 @@ func Test_encodeBranch(t *testing.T) { key: []byte{1, 2, 3}, value: []byte{100}, children: [16]Node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, + nil, nil, nil, &Leaf{key: []byte{9}}, + nil, nil, nil, &Leaf{key: []byte{11}}, }, }, writes: []writeCall{ @@ -375,8 +375,8 @@ func Test_encodeBranch(t *testing.T) { key: []byte{1, 2, 3}, value: []byte{100}, children: [16]Node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, + nil, nil, nil, &Leaf{key: []byte{9}}, + nil, nil, nil, &Leaf{key: []byte{11}}, }, }, writes: []writeCall{ @@ -407,8 +407,8 @@ func Test_encodeBranch(t *testing.T) { key: []byte{1, 2, 3}, value: []byte{100}, children: [16]Node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, + nil, nil, nil, &Leaf{key: []byte{9}}, + nil, nil, nil, &Leaf{key: []byte{11}}, }, }, writes: []writeCall{ @@ -443,8 +443,8 @@ func Test_encodeBranch(t *testing.T) { key: []byte{1, 2, 3}, value: []byte{100}, children: [16]Node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, + nil, nil, nil, &Leaf{key: []byte{9}}, + nil, nil, nil, &Leaf{key: []byte{11}}, }, }, writes: []writeCall{ @@ -474,8 +474,8 @@ func Test_encodeBranch(t *testing.T) { key: []byte{1, 2, 3}, value: []byte{100}, children: [16]Node{ - nil, nil, nil, &leaf{key: []byte{9}}, - nil, nil, nil, &leaf{key: []byte{11}}, + nil, nil, 
nil, &Leaf{key: []byte{9}}, + nil, nil, nil, &Leaf{key: []byte{11}}, }, }, writes: []writeCall{ @@ -546,7 +546,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { "no children": {}, "first child not nil": { children: [16]Node{ - &leaf{key: []byte{1}}, + &Leaf{key: []byte{1}}, }, writes: []writeCall{ { @@ -559,7 +559,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf{key: []byte{1}}, + &Leaf{key: []byte{1}}, }, writes: []writeCall{ { @@ -569,8 +569,8 @@ func Test_encodeChildrenInParallel(t *testing.T) { }, "first two children not nil": { children: [16]Node{ - &leaf{key: []byte{1}}, - &leaf{key: []byte{2}}, + &Leaf{key: []byte{1}}, + &Leaf{key: []byte{2}}, }, writes: []writeCall{ { @@ -586,7 +586,7 @@ func Test_encodeChildrenInParallel(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf{ + &Leaf{ key: []byte{1}, }, nil, nil, nil, nil, @@ -646,7 +646,7 @@ func Test_encodeChildrenSequentially(t *testing.T) { "no children": {}, "first child not nil": { children: [16]Node{ - &leaf{key: []byte{1}}, + &Leaf{key: []byte{1}}, }, writes: []writeCall{ { @@ -659,7 +659,7 @@ func Test_encodeChildrenSequentially(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf{key: []byte{1}}, + &Leaf{key: []byte{1}}, }, writes: []writeCall{ { @@ -669,8 +669,8 @@ func Test_encodeChildrenSequentially(t *testing.T) { }, "first two children not nil": { children: [16]Node{ - &leaf{key: []byte{1}}, - &leaf{key: []byte{2}}, + &Leaf{key: []byte{1}}, + &Leaf{key: []byte{2}}, }, writes: []writeCall{ { @@ -686,7 +686,7 @@ func Test_encodeChildrenSequentially(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf{ + &Leaf{ key: []byte{1}, }, nil, nil, nil, nil, @@ -748,13 +748,13 @@ func Test_encodeChild(t *testing.T) { }{ "nil node": {}, "nil leaf": { - child: (*leaf)(nil), + child: (*Leaf)(nil), }, "nil branch": { child: (*Branch)(nil), }, "empty leaf child": { - child: &leaf{}, + child: &Leaf{}, writeCall: true, write: writeCall{ written: []byte{8, 64, 0}, @@ -778,7 +778,7 @@ func Test_encodeChild(t *testing.T) { errMessage: "failed to write child to buffer: test error", }, "leaf child": { - child: &leaf{ + child: &Leaf{ key: []byte{1}, value: []byte{2}, }, @@ -792,7 +792,7 @@ func Test_encodeChild(t *testing.T) { key: []byte{1}, value: []byte{2}, children: [16]Node{ - nil, nil, &leaf{ + nil, nil, &Leaf{ key: []byte{5}, value: []byte{6}, }, @@ -846,10 +846,10 @@ func Test_encodeAndHash(t *testing.T) { errMessage: "cannot hash node: " + "cannot encode node: " + "node type is not supported: " + - "*trie.Mocknode", + "*trie.MockNode", }, "leaf": { - n: &leaf{}, + n: &Leaf{}, b: []byte{0x8, 0x40, 0}, }, } @@ -877,13 +877,13 @@ func Test_encodeLeaf(t *testing.T) { t.Parallel() testCases := map[string]struct { - leaf *leaf + leaf *Leaf writes []writeCall wrappedErr error errMessage string }{ "clean leaf with encoding": { - leaf: &leaf{ + leaf: &Leaf{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -893,7 +893,7 @@ func Test_encodeLeaf(t *testing.T) { }, }, "write error for clean leaf with encoding": { - leaf: &leaf{ + leaf: &Leaf{ encoding: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -906,14 +906,14 @@ func Test_encodeLeaf(t *testing.T) { errMessage: "cannot write stored encoding to buffer: test error", }, "header encoding error": { - leaf: &leaf{ + leaf: &Leaf{ key: make([]byte, 63+(1<<16)), }, wrappedErr: ErrPartialKeyTooBig, errMessage: 
"cannot encode header: partial key length greater than or equal to 2^16", }, "buffer write error for encoded header": { - leaf: &leaf{ + leaf: &Leaf{ key: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -926,7 +926,7 @@ func Test_encodeLeaf(t *testing.T) { errMessage: "cannot write encoded header to buffer: test error", }, "buffer write error for encoded key": { - leaf: &leaf{ + leaf: &Leaf{ key: []byte{1, 2, 3}, }, writes: []writeCall{ @@ -942,7 +942,7 @@ func Test_encodeLeaf(t *testing.T) { errMessage: "cannot write LE key to buffer: test error", }, "buffer write error for encoded value": { - leaf: &leaf{ + leaf: &Leaf{ key: []byte{1, 2, 3}, value: []byte{4, 5, 6}, }, @@ -962,7 +962,7 @@ func Test_encodeLeaf(t *testing.T) { errMessage: "cannot write scale encoded value to buffer: test error", }, "success": { - leaf: &leaf{ + leaf: &Leaf{ key: []byte{1, 2, 3}, value: []byte{4, 5, 6}, }, diff --git a/lib/trie/node.go b/lib/trie/node.go index 17ff14f727..1a61a7d124 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -64,7 +64,7 @@ type ( generation uint64 sync.RWMutex } - leaf struct { + Leaf struct { key []byte // partial key value []byte dirty bool @@ -80,7 +80,7 @@ func (b *Branch) SetGeneration(generation uint64) { b.generation = generation } -func (l *leaf) SetGeneration(generation uint64) { +func (l *Leaf) SetGeneration(generation uint64) { l.generation = generation } @@ -110,14 +110,14 @@ func (b *Branch) Copy() Node { return cpy } -func (l *leaf) Copy() Node { +func (l *Leaf) Copy() Node { l.RLock() defer l.RUnlock() l.encodingMu.RLock() defer l.encodingMu.RUnlock() - cpy := &leaf{ + cpy := &Leaf{ key: make([]byte, len(l.key)), value: make([]byte, len(l.value)), dirty: l.dirty, @@ -137,7 +137,7 @@ func (b *Branch) SetEncodingAndHash(enc, hash []byte) { b.hash = hash } -func (l *leaf) SetEncodingAndHash(enc, hash []byte) { +func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { l.encodingMu.Lock() l.encoding = enc l.encodingMu.Unlock() @@ -153,11 +153,11 @@ func (b *Branch) GetGeneration() uint64 { return b.generation } -func (l *leaf) GetGeneration() uint64 { +func (l *Leaf) GetGeneration() uint64 { return l.generation } -func (l *leaf) GetHash() []byte { +func (l *Leaf) GetHash() []byte { return l.hash } @@ -170,7 +170,7 @@ func (b *Branch) String() string { return fmt.Sprintf("branch key=%x childrenBitmap=%16b value=%v dirty=%v", b.key, b.childrenBitmap(), b.value, b.dirty) } -func (l *leaf) String() string { +func (l *Leaf) String() string { if len(l.value) > 1024 { return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.key, common.MustBlake2bHash(l.value), l.dirty) } @@ -198,7 +198,7 @@ func (b *Branch) numChildren() int { return count } -func (l *leaf) IsDirty() bool { +func (l *Leaf) IsDirty() bool { return l.dirty } @@ -206,7 +206,7 @@ func (b *Branch) IsDirty() bool { return b.dirty } -func (l *leaf) SetDirty(dirty bool) { +func (l *Leaf) SetDirty(dirty bool) { l.dirty = dirty } @@ -214,7 +214,7 @@ func (b *Branch) SetDirty(dirty bool) { b.dirty = dirty } -func (l *leaf) SetKey(key []byte) { +func (l *Leaf) SetKey(key []byte) { l.key = key } @@ -260,7 +260,7 @@ func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { return encoding, hash, nil } -func (l *leaf) EncodeAndHash() (encoding, hash []byte, err error) { +func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { l.encodingMu.RLock() if !l.IsDirty() && l.encoding != nil && l.hash != nil { l.encodingMu.RUnlock() @@ -320,7 +320,7 @@ func decode(r io.Reader) (Node, error) { nodeType := 
header >> 6 if nodeType == 1 { - l := new(leaf) + l := new(Leaf) err := l.Decode(r, header) return l, err } else if nodeType == 2 || nodeType == 3 { @@ -380,7 +380,7 @@ func (b *Branch) Decode(r io.Reader, header byte) (err error) { return err } - b.children[i] = &leaf{ + b.children[i] = &Leaf{ hash: hash, } } @@ -392,7 +392,7 @@ func (b *Branch) Decode(r io.Reader, header byte) (err error) { } // Decode decodes a byte array with the encoding specified at the top of this package into a leaf node -func (l *leaf) Decode(r io.Reader, header byte) (err error) { +func (l *Leaf) Decode(r io.Reader, header byte) (err error) { if header == 0 { header, err = readByte(r) if err != nil { @@ -451,7 +451,7 @@ func (b *Branch) header() ([]byte, error) { return fullHeader, nil } -func (l *leaf) header() ([]byte, error) { +func (l *Leaf) header() ([]byte, error) { var header byte = 1 << 6 var encodePkLen []byte var err error diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index be380e10be..52b087c541 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -42,19 +42,19 @@ func TestChildrenBitmap(t *testing.T) { t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) } - b.children[0] = &leaf{key: []byte{0x00}, value: []byte{0x00}} + b.children[0] = &Leaf{key: []byte{0x00}, value: []byte{0x00}} res = b.childrenBitmap() if res != 1 { t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) } - b.children[4] = &leaf{key: []byte{0x00}, value: []byte{0x00}} + b.children[4] = &Leaf{key: []byte{0x00}, value: []byte{0x00}} res = b.childrenBitmap() if res != 1<<4+1 { t.Errorf("Fail to get children bitmap: got %x expected %x", res, 17) } - b.children[15] = &leaf{key: []byte{0x00}, value: []byte{0x00}} + b.children[15] = &Leaf{key: []byte{0x00}, value: []byte{0x00}} res = b.childrenBitmap() if res != 1<<15+1<<4+1 { t.Errorf("Fail to get children bitmap: got %x expected %x", res, 257) @@ -115,18 +115,18 @@ func TestFailingPk(t *testing.T) { func TestLeafHeader(t *testing.T) { tests := []struct { - br *leaf + br *Leaf header []byte }{ - {&leaf{key: nil, value: nil}, []byte{0x40}}, - {&leaf{key: []byte{0x00}, value: nil}, []byte{0x41}}, - {&leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, []byte{0x44}}, - {&leaf{key: byteArray(62), value: nil}, []byte{0x7e}}, - {&leaf{key: byteArray(63), value: nil}, []byte{0x7f, 0}}, - {&leaf{key: byteArray(64), value: []byte{0x01}}, []byte{0x7f, 1}}, - - {&leaf{key: byteArray(318), value: []byte{0x01}}, []byte{0x7f, 0xff, 0}}, - {&leaf{key: byteArray(573), value: []byte{0x01}}, []byte{0x7f, 0xff, 0xff, 0}}, + {&Leaf{key: nil, value: nil}, []byte{0x40}}, + {&Leaf{key: []byte{0x00}, value: nil}, []byte{0x41}}, + {&Leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, []byte{0x44}}, + {&Leaf{key: byteArray(62), value: nil}, []byte{0x7e}}, + {&Leaf{key: byteArray(63), value: nil}, []byte{0x7f, 0}}, + {&Leaf{key: byteArray(64), value: []byte{0x01}}, []byte{0x7f, 1}}, + + {&Leaf{key: byteArray(318), value: []byte{0x01}}, []byte{0x7f, 0xff, 0}}, + {&Leaf{key: byteArray(573), value: []byte{0x01}}, []byte{0x7f, 0xff, 0xff, 0}}, } for i, test := range tests { @@ -188,7 +188,7 @@ func TestLeafEncode(t *testing.T) { randVals := generateRand(100) for i, testKey := range randKeys { - l := &leaf{key: testKey, value: randVals[i]} + l := &Leaf{key: testKey, value: randVals[i]} expected := []byte{} header, err := l.header() @@ -239,14 +239,14 @@ func TestBranchDecode(t *testing.T) { {key: []byte{0x00}, children: [16]Node{}, value: nil}, {key: []byte{0x00, 
0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, {key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, - {key: []byte{}, children: [16]Node{&leaf{}}, value: []byte{0x01}}, - {key: []byte{}, children: [16]Node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, + {key: []byte{}, children: [16]Node{&Leaf{}}, value: []byte{0x01}}, + {key: []byte{}, children: [16]Node{&Leaf{}, nil, &Leaf{}}, value: []byte{0x01}}, { key: []byte{}, children: [16]Node{ - &leaf{}, nil, &leaf{}, nil, + &Leaf{}, nil, &Leaf{}, nil, nil, nil, nil, nil, - nil, &leaf{}, nil, &leaf{}, + nil, &Leaf{}, nil, &Leaf{}, }, value: []byte{0x01}, }, @@ -276,7 +276,7 @@ func TestBranchDecode(t *testing.T) { } func TestLeafDecode(t *testing.T) { - tests := []*leaf{ + tests := []*Leaf{ {key: []byte{}, value: nil, dirty: true}, {key: []byte{0x01}, value: nil, dirty: true}, {key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil, dirty: true}, @@ -293,7 +293,7 @@ func TestLeafDecode(t *testing.T) { err := encodeLeaf(test, buffer) require.NoError(t, err) - res := new(leaf) + res := new(Leaf) err = res.Decode(buffer, 0) require.NoError(t, err) @@ -309,24 +309,24 @@ func TestDecode(t *testing.T) { &Branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, &Branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, &Branch{key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, - &Branch{key: []byte{}, children: [16]Node{&leaf{}}, value: []byte{0x01}}, - &Branch{key: []byte{}, children: [16]Node{&leaf{}, nil, &leaf{}}, value: []byte{0x01}}, + &Branch{key: []byte{}, children: [16]Node{&Leaf{}}, value: []byte{0x01}}, + &Branch{key: []byte{}, children: [16]Node{&Leaf{}, nil, &Leaf{}}, value: []byte{0x01}}, &Branch{ key: []byte{}, children: [16]Node{ - &leaf{}, nil, &leaf{}, nil, + &Leaf{}, nil, &Leaf{}, nil, nil, nil, nil, nil, - nil, &leaf{}, nil, &leaf{}}, + nil, &Leaf{}, nil, &Leaf{}}, value: []byte{0x01}, }, - &leaf{key: []byte{}, value: nil}, - &leaf{key: []byte{0x00}, value: nil}, - &leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, - &leaf{key: byteArray(62), value: nil}, - &leaf{key: byteArray(63), value: nil}, - &leaf{key: byteArray(64), value: []byte{0x01}}, - &leaf{key: byteArray(318), value: []byte{0x01}}, - &leaf{key: byteArray(573), value: []byte{0x01}}, + &Leaf{key: []byte{}, value: nil}, + &Leaf{key: []byte{0x00}, value: nil}, + &Leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, + &Leaf{key: byteArray(62), value: nil}, + &Leaf{key: byteArray(63), value: nil}, + &Leaf{key: byteArray(64), value: []byte{0x01}}, + &Leaf{key: byteArray(318), value: []byte{0x01}}, + &Leaf{key: byteArray(573), value: []byte{0x01}}, } buffer := bytes.NewBuffer(nil) @@ -344,9 +344,9 @@ func TestDecode(t *testing.T) { require.Equal(t, n.key, res.(*Branch).key) require.Equal(t, n.childrenBitmap(), res.(*Branch).childrenBitmap()) require.Equal(t, n.value, res.(*Branch).value) - case *leaf: - require.Equal(t, n.key, res.(*leaf).key) - require.Equal(t, n.value, res.(*leaf).value) + case *Leaf: + require.Equal(t, n.key, res.(*Leaf).key) + require.Equal(t, n.value, res.(*Leaf).value) default: t.Fatal("unexpected node") } diff --git a/lib/trie/print.go b/lib/trie/print.go index c8824a5ade..1804f86d99 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -48,7 +48,7 @@ func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { t.string(sub, child, i) } } - case *leaf: + case *Leaf: buffer := encodingBufferPool.Get().(*bytes.Buffer) buffer.Reset() diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 2881f3a781..ad99492c8b 100644 
--- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -149,7 +149,7 @@ func (t *Trie) entries(current Node, prefix []byte, kv map[string][]byte) map[st for i, child := range c.children { t.entries(child, append(prefix, append(c.key, byte(i))...), kv) } - case *leaf: + case *Leaf: kv[string(nibblesToKeyLE(append(prefix, c.key...)))] = c.value return kv } @@ -220,7 +220,7 @@ func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { } } } - case *leaf: + case *Leaf: fullKey := append(prefix, c.key...) var cmp int if len(key) < len(fullKey) { @@ -252,7 +252,7 @@ func (t *Trie) Put(key, value []byte) { func (t *Trie) tryPut(key, value []byte) { k := keyToNibbles(key) - t.root = t.insert(t.root, k, &leaf{key: nil, value: value, dirty: true, generation: t.generation}) + t.root = t.insert(t.root, k, &Leaf{key: nil, value: value, dirty: true, generation: t.generation}) } // insert attempts to insert a key with value into the trie @@ -268,12 +268,12 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { case nil: value.SetKey(key) return value - case *leaf: + case *Leaf: // if a value already exists in the trie at this key, overwrite it with the new value // if the values are the same, don't mark node dirty if p.value != nil && bytes.Equal(p.key, key) { - if !bytes.Equal(value.(*leaf).value, p.value) { - p.value = value.(*leaf).value + if !bytes.Equal(value.(*Leaf).value, p.value) { + p.value = value.(*Leaf).value p.dirty = true } return p @@ -287,7 +287,7 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { // value goes at this branch if len(key) == length { - br.value = value.(*leaf).value + br.value = value.(*Leaf).value br.SetDirty(true) // if we are not replacing previous leaf, then add it as a child to the new branch @@ -335,14 +335,14 @@ func (t *Trie) updateBranch(p *Branch, key []byte, value Node) (n Node) { switch v := value.(type) { case *Branch: p.value = v.value - case *leaf: + case *Leaf: p.value = v.value } return p } switch c := p.children[key[length]].(type) { - case *Branch, *leaf: + case *Branch, *Leaf: n = t.insert(c, key[length+1:], value) p.children[key[length]] = n n.SetDirty(true) @@ -350,7 +350,7 @@ func (t *Trie) updateBranch(p *Branch, key []byte, value Node) (n Node) { return p case nil: // otherwise, add node as child of this branch - value.(*leaf).key = key[length+1:] + value.(*Leaf).key = key[length+1:] p.children[key[length]] = value p.SetDirty(true) return p @@ -367,7 +367,7 @@ func (t *Trie) updateBranch(p *Branch, key []byte, value Node) (n Node) { br.children[parentIndex] = t.insert(nil, p.key[length+1:], p) if len(key) <= length { - br.value = value.(*leaf).value + br.value = value.(*Leaf).value } else { br.children[key[length]] = t.insert(nil, key[length+1:], value) } @@ -424,7 +424,7 @@ func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) key = key[len(p.key):] keys = t.getKeysWithPrefix(p.children[key[0]], append(append(prefix, p.key...), key[0]), key[1:], keys) - case *leaf: + case *Leaf: length := lenCommonPrefix(p.key, key) if bytes.Equal(p.key[:length], key) || len(key) == 0 { keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) @@ -447,7 +447,7 @@ func (t *Trie) addAllKeys(parent Node, prefix []byte, keys [][]byte) [][]byte { for i, child := range p.children { keys = t.addAllKeys(child, append(append(prefix, p.key...), byte(i)), keys) } - case *leaf: + case *Leaf: keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) case nil: return keys @@ -466,14 +466,14 @@ func (t *Trie) 
Get(key []byte) []byte { return l.value } -func (t *Trie) tryGet(key []byte) *leaf { +func (t *Trie) tryGet(key []byte) *Leaf { k := keyToNibbles(key) return t.retrieve(t.root, k) } -func (t *Trie) retrieve(parent Node, key []byte) *leaf { +func (t *Trie) retrieve(parent Node, key []byte) *Leaf { var ( - value *leaf + value *Leaf ) switch p := parent.(type) { @@ -482,7 +482,7 @@ func (t *Trie) retrieve(parent Node, key []byte) *leaf { // found the value at this node if bytes.Equal(p.key, key) || len(key) == 0 { - return &leaf{key: p.key, value: p.value, dirty: false} + return &Leaf{key: p.key, value: p.value, dirty: false} } // did not find value @@ -491,7 +491,7 @@ func (t *Trie) retrieve(parent Node, key []byte) *leaf { } value = t.retrieve(p.children[key[length]], key[length+1:]) - case *leaf: + case *Leaf: if bytes.Equal(p.key, key) { value = p } @@ -562,7 +562,7 @@ func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bo } return curr, curr.IsDirty(), allDeleted - case *leaf: + case *Leaf: length := lenCommonPrefix(c.key, prefix) if length == len(prefix) { *limit-- @@ -582,7 +582,7 @@ func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *leaf: + case *Leaf: if *limit == 0 { return c, false } @@ -681,7 +681,7 @@ func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { } return curr, curr.IsDirty() - case *leaf: + case *Leaf: length := lenCommonPrefix(c.key, prefix) if length == len(prefix) { return nil, true @@ -723,7 +723,7 @@ func (t *Trie) delete(parent Node, key []byte) (Node, bool) { p.SetDirty(true) n = handleDeletion(p, key) return n, true - case *leaf: + case *Leaf: if bytes.Equal(key, p.key) || len(key) == 0 { // Key exists. Delete it. return nil, true @@ -747,7 +747,7 @@ func handleDeletion(p *Branch, key []byte) Node { // if branch has no children, just a value, turn it into a leaf if bitmap == 0 && p.value != nil { - n = &leaf{key: key[:length], value: p.value, dirty: true} + n = &Leaf{key: key[:length], value: p.value, dirty: true} } else if p.numChildren() == 1 && p.value == nil { // there is only 1 child and no value, combine the child branch with this branch // find index of child @@ -761,8 +761,8 @@ func handleDeletion(p *Branch, key []byte) Node { child := p.children[i] switch c := child.(type) { - case *leaf: - n = &leaf{key: append(append(p.key, []byte{byte(i)}...), c.key...), value: c.value} + case *Leaf: + n = &Leaf{key: append(append(p.key, []byte{byte(i)}...), c.key...), value: c.value} case *Branch: br := new(Branch) br.key = append(p.key, append([]byte{byte(i)}, c.key...)...) diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index 3d3bf170d1..29cdc54e0a 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -68,7 +68,7 @@ func TestNewEmptyTrie(t *testing.T) { } func TestNewTrie(t *testing.T) { - trie := NewTrie(&leaf{key: []byte{0}, value: []byte{17}}) + trie := NewTrie(&Leaf{key: []byte{0}, value: []byte{17}}) if trie == nil { t.Error("did not initialise trie") } @@ -942,7 +942,7 @@ func TestClearPrefix_Small(t *testing.T) { } ssTrie.ClearPrefix([]byte("noo")) - require.Equal(t, ssTrie.root, &leaf{key: keyToNibbles([]byte("other")), value: []byte("other"), dirty: true}) + require.Equal(t, ssTrie.root, &Leaf{key: keyToNibbles([]byte("other")), value: []byte("other"), dirty: true}) // Get the updated root hash of all tries. 
tHash, err = trie.Hash() From 5bc7d18edf34fedd150a842072a6262b1c116dcf Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 23 Nov 2021 13:51:27 +0000 Subject: [PATCH 05/50] Add exported comments --- lib/trie/node.go | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/lib/trie/node.go b/lib/trie/node.go index 1a61a7d124..7c9f316a31 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -54,6 +54,7 @@ type Node interface { } type ( + // Branch is a branch in the trie. Branch struct { key []byte // partial key children [16]Node @@ -64,6 +65,8 @@ type ( generation uint64 sync.RWMutex } + + // Leaf is a leaf in the trie. Leaf struct { key []byte // partial key value []byte @@ -76,14 +79,17 @@ type ( } ) +// SetGeneration sets the generation given to the branch. func (b *Branch) SetGeneration(generation uint64) { b.generation = generation } +// SetGeneration sets the generation given to the leaf. func (l *Leaf) SetGeneration(generation uint64) { l.generation = generation } +// Copy deep copies the branch. func (b *Branch) Copy() Node { b.RLock() defer b.RUnlock() @@ -110,6 +116,7 @@ func (b *Branch) Copy() Node { return cpy } +// Copy deep copies the leaf. func (l *Leaf) Copy() Node { l.RLock() defer l.RUnlock() @@ -132,11 +139,15 @@ func (l *Leaf) Copy() Node { return cpy } +// SetEncodingAndHash sets the encoding and hash slices +// given to the branch. Note it does not copy them, so beware. func (b *Branch) SetEncodingAndHash(enc, hash []byte) { b.encoding = enc b.hash = hash } +// SetEncodingAndHash sets the encoding and hash slices +// given to the branch. Note it does not copy them, so beware. func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { l.encodingMu.Lock() l.encoding = enc @@ -145,18 +156,28 @@ func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { l.hash = hash } +// GetHash returns the hash of the branch. +// Note it does not copy it, so modifying +// the returned hash will modify the hash +// of the branch. func (b *Branch) GetHash() []byte { return b.hash } +// GetGeneration returns the generation of the branch. func (b *Branch) GetGeneration() uint64 { return b.generation } +// GetGeneration returns the generation of the leaf. func (l *Leaf) GetGeneration() uint64 { return l.generation } +// GetHash returns the hash of the leaf. +// Note it does not copy it, so modifying +// the returned hash will modify the hash +// of the branch. func (l *Leaf) GetHash() []byte { return l.hash } @@ -198,30 +219,44 @@ func (b *Branch) numChildren() int { return count } +// IsDirty returns the dirty status of the leaf. func (l *Leaf) IsDirty() bool { return l.dirty } +// IsDirty returns the dirty status of the branch. func (b *Branch) IsDirty() bool { return b.dirty } +// SetDirty sets the dirty status to the leaf. func (l *Leaf) SetDirty(dirty bool) { l.dirty = dirty } +// SetDirty sets the dirty status to the branch. func (b *Branch) SetDirty(dirty bool) { b.dirty = dirty } +// SetKey sets the key to the leaf. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the leaf. func (l *Leaf) SetKey(key []byte) { l.key = key } +// SetKey sets the key to the branch. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the branch. func (b *Branch) SetKey(key []byte) { b.key = key } +// EncodeAndHash returns the encoding of the branch and +// the blake2b hash digest of the encoding of the branch. 
+// If the encoding is less than 32 bytes, the hash returned +// is the encoding and not the hash of the encoding. func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { if !b.dirty && b.encoding != nil && b.hash != nil { return b.encoding, b.hash, nil @@ -260,6 +295,10 @@ func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { return encoding, hash, nil } +// EncodeAndHash returns the encoding of the leaf and +// the blake2b hash digest of the encoding of the leaf. +// If the encoding is less than 32 bytes, the hash returned +// is the encoding and not the hash of the encoding. func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { l.encodingMu.RLock() if !l.IsDirty() && l.encoding != nil && l.hash != nil { From 59ab0031201f77625398a02942cbf8c9bc0fc967 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 29 Nov 2021 15:21:10 +0000 Subject: [PATCH 06/50] Refactor encoding and hash related code with tests --- lib/trie/branch/branch.go | 35 + lib/trie/branch/buffer_mock_test.go | 77 ++ lib/trie/branch/children.go | 29 + lib/trie/branch/children_test.go | 122 +++ lib/trie/branch/copy.go | 33 + lib/trie/branch/decode.go | 87 ++ lib/trie/branch/decode_test.go | 168 ++++ lib/trie/branch/dirty.go | 14 + lib/trie/branch/encode.go | 242 +++++ lib/trie/branch/encode_test.go | 605 ++++++++++++ lib/trie/branch/generation.go | 14 + lib/trie/branch/hash.go | 68 ++ lib/trie/branch/header.go | 32 + lib/trie/branch/header_test.go | 80 ++ lib/trie/branch/key.go | 11 + lib/trie/branch/writer_mock_test.go | 49 + lib/trie/bytesBuffer_mock_test.go | 77 -- lib/trie/codec.go | 67 -- lib/trie/codec_test.go | 80 -- lib/trie/database.go | 82 +- lib/trie/decode.go | 45 + lib/trie/decode/byte.go | 16 + lib/trie/decode/byte_test.go | 52 + lib/trie/decode/key.go | 77 ++ lib/trie/decode/key_test.go | 138 +++ lib/trie/decode_test.go | 106 ++ lib/trie/encode/buffer.go | 16 + lib/trie/encode/doc.go | 28 + lib/trie/encode/key.go | 84 ++ lib/trie/encode/key_test.go | 170 ++++ lib/trie/encodedecode_test/branch_test.go | 89 ++ lib/trie/encodedecode_test/nibbles_test.go | 50 + lib/trie/hash.go | 350 ------- lib/trie/hash_test.go | 1012 -------------------- lib/trie/leaf/buffer_mock_test.go | 77 ++ lib/trie/leaf/copy.go | 29 + lib/trie/leaf/decode.go | 57 ++ lib/trie/leaf/decode_test.go | 121 +++ lib/trie/leaf/dirty.go | 14 + lib/trie/leaf/encode.go | 203 ++++ lib/trie/leaf/encode_test.go | 318 ++++++ lib/trie/leaf/generation.go | 14 + lib/trie/leaf/header.go | 26 + lib/trie/leaf/header_test.go | 74 ++ lib/trie/leaf/key.go | 11 + lib/trie/leaf/leaf.go | 33 + lib/trie/leaf/writer_mock_test.go | 49 + lib/trie/lookup.go | 15 +- lib/trie/node.go | 577 ----------- lib/trie/node/interface.go | 25 + lib/trie/node/types.go | 17 + lib/trie/node_mock_test.go | 183 ---- lib/trie/node_test.go | 327 +------ lib/trie/pools/pools.go | 42 + lib/trie/print.go | 45 +- lib/trie/proof.go | 3 +- lib/trie/trie.go | 331 ++++--- lib/trie/trie_test.go | 30 +- 58 files changed, 3917 insertions(+), 2909 deletions(-) create mode 100644 lib/trie/branch/branch.go create mode 100644 lib/trie/branch/buffer_mock_test.go create mode 100644 lib/trie/branch/children.go create mode 100644 lib/trie/branch/children_test.go create mode 100644 lib/trie/branch/copy.go create mode 100644 lib/trie/branch/decode.go create mode 100644 lib/trie/branch/decode_test.go create mode 100644 lib/trie/branch/dirty.go create mode 100644 lib/trie/branch/encode.go create mode 100644 lib/trie/branch/encode_test.go create mode 100644 
lib/trie/branch/generation.go create mode 100644 lib/trie/branch/hash.go create mode 100644 lib/trie/branch/header.go create mode 100644 lib/trie/branch/header_test.go create mode 100644 lib/trie/branch/key.go create mode 100644 lib/trie/branch/writer_mock_test.go delete mode 100644 lib/trie/bytesBuffer_mock_test.go delete mode 100644 lib/trie/codec.go delete mode 100644 lib/trie/codec_test.go create mode 100644 lib/trie/decode.go create mode 100644 lib/trie/decode/byte.go create mode 100644 lib/trie/decode/byte_test.go create mode 100644 lib/trie/decode/key.go create mode 100644 lib/trie/decode/key_test.go create mode 100644 lib/trie/decode_test.go create mode 100644 lib/trie/encode/buffer.go create mode 100644 lib/trie/encode/doc.go create mode 100644 lib/trie/encode/key.go create mode 100644 lib/trie/encode/key_test.go create mode 100644 lib/trie/encodedecode_test/branch_test.go create mode 100644 lib/trie/encodedecode_test/nibbles_test.go delete mode 100644 lib/trie/hash.go delete mode 100644 lib/trie/hash_test.go create mode 100644 lib/trie/leaf/buffer_mock_test.go create mode 100644 lib/trie/leaf/copy.go create mode 100644 lib/trie/leaf/decode.go create mode 100644 lib/trie/leaf/decode_test.go create mode 100644 lib/trie/leaf/dirty.go create mode 100644 lib/trie/leaf/encode.go create mode 100644 lib/trie/leaf/encode_test.go create mode 100644 lib/trie/leaf/generation.go create mode 100644 lib/trie/leaf/header.go create mode 100644 lib/trie/leaf/header_test.go create mode 100644 lib/trie/leaf/key.go create mode 100644 lib/trie/leaf/leaf.go create mode 100644 lib/trie/leaf/writer_mock_test.go delete mode 100644 lib/trie/node.go create mode 100644 lib/trie/node/interface.go create mode 100644 lib/trie/node/types.go delete mode 100644 lib/trie/node_mock_test.go create mode 100644 lib/trie/pools/pools.go diff --git a/lib/trie/branch/branch.go b/lib/trie/branch/branch.go new file mode 100644 index 0000000000..a5d0dfc36d --- /dev/null +++ b/lib/trie/branch/branch.go @@ -0,0 +1,35 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +import ( + "fmt" + "sync" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/node" +) + +var _ node.Node = (*Branch)(nil) + +// Branch is a branch in the trie. +type Branch struct { + Key []byte // partial key + Children [16]node.Node + Value []byte + Dirty bool + Hash []byte + Encoding []byte + Generation uint64 + sync.RWMutex +} + +func (b *Branch) String() string { + if len(b.Value) > 1024 { + return fmt.Sprintf("key=%x childrenBitmap=%16b value (hashed)=%x dirty=%v", + b.Key, b.ChildrenBitmap(), common.MustBlake2bHash(b.Value), b.Dirty) + } + return fmt.Sprintf("key=%x childrenBitmap=%16b value=%v dirty=%v", + b.Key, b.ChildrenBitmap(), b.Value, b.Dirty) +} diff --git a/lib/trie/branch/buffer_mock_test.go b/lib/trie/branch/buffer_mock_test.go new file mode 100644 index 0000000000..3e864b1cfc --- /dev/null +++ b/lib/trie/branch/buffer_mock_test.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/lib/trie/encode (interfaces: Buffer) + +// Package branch is a generated GoMock package. +package branch + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockBuffer is a mock of Buffer interface. +type MockBuffer struct { + ctrl *gomock.Controller + recorder *MockBufferMockRecorder +} + +// MockBufferMockRecorder is the mock recorder for MockBuffer. 
+type MockBufferMockRecorder struct { + mock *MockBuffer +} + +// NewMockBuffer creates a new mock instance. +func NewMockBuffer(ctrl *gomock.Controller) *MockBuffer { + mock := &MockBuffer{ctrl: ctrl} + mock.recorder = &MockBufferMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBuffer) EXPECT() *MockBufferMockRecorder { + return m.recorder +} + +// Bytes mocks base method. +func (m *MockBuffer) Bytes() []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Bytes") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Bytes indicates an expected call of Bytes. +func (mr *MockBufferMockRecorder) Bytes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockBuffer)(nil).Bytes)) +} + +// Len mocks base method. +func (m *MockBuffer) Len() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Len") + ret0, _ := ret[0].(int) + return ret0 +} + +// Len indicates an expected call of Len. +func (mr *MockBufferMockRecorder) Len() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockBuffer)(nil).Len)) +} + +// Write mocks base method. +func (m *MockBuffer) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockBufferMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockBuffer)(nil).Write), arg0) +} diff --git a/lib/trie/branch/children.go b/lib/trie/branch/children.go new file mode 100644 index 0000000000..54395458da --- /dev/null +++ b/lib/trie/branch/children.go @@ -0,0 +1,29 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +// ChildrenBitmap returns the 16 bit bitmap +// of the children in the branch. +func (b *Branch) ChildrenBitmap() uint16 { + var bitmap uint16 + var i uint + for i = 0; i < 16; i++ { + if b.Children[i] != nil { + bitmap = bitmap | 1<> 6 + if nodeType != 2 && nodeType != 3 { + return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) + } + + branch = new(Branch) + + keyLen := header & 0x3f + branch.Key, err = decode.Key(reader, keyLen) + if err != nil { + return nil, fmt.Errorf("cannot decode key: %w", err) + } + + childrenBitmap := make([]byte, 2) + _, err = reader.Read(childrenBitmap) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadChildrenBitmap, err) + } + + sd := scale.NewDecoder(reader) + + if nodeType == 3 { + var value []byte + // branch w/ value + err := sd.Decode(&value) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) + } + branch.Value = value + } + + for i := 0; i < 16; i++ { + if (childrenBitmap[i/8]>>(i%8))&1 != 1 { + continue + } + var hash []byte + err := sd.Decode(&hash) + if err != nil { + return nil, fmt.Errorf("%w: at index %d: %s", + ErrDecodeChildHash, i, err) + } + + branch.Children[i] = &leaf.Leaf{ + Hash: hash, + } + } + + branch.Dirty = true // TODO move as soon as it gets modified? 
+ + return branch, nil +} diff --git a/lib/trie/branch/decode_test.go b/lib/trie/branch/decode_test.go new file mode 100644 index 0000000000..9d89ea2d7a --- /dev/null +++ b/lib/trie/branch/decode_test.go @@ -0,0 +1,168 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +import ( + "bytes" + "io" + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { + encoded, err := scale.Marshal(b) + require.NoError(t, err) + return encoded +} + +func concatByteSlices(slices [][]byte) (concatenated []byte) { + length := 0 + for i := range slices { + length += len(slices[i]) + } + concatenated = make([]byte, 0, length) + for _, slice := range slices { + concatenated = append(concatenated, slice...) + } + return concatenated +} + +func Test_Decode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + header byte + branch *Branch + errWrapped error + errMessage string + }{ + "no data with header 0": { + reader: bytes.NewBuffer(nil), + errWrapped: ErrReadHeaderByte, + errMessage: "cannot read header byte: EOF", + }, + "no data with header 1": { + reader: bytes.NewBuffer(nil), + header: 1, + errWrapped: ErrNodeTypeIsNotABranch, + errMessage: "node type is not a branch: 0", + }, + "first byte as 0 header 0": { + reader: bytes.NewBuffer([]byte{0}), + errWrapped: ErrNodeTypeIsNotABranch, + errMessage: "node type is not a branch: 0", + }, + "key decoding error": { + reader: bytes.NewBuffer([]byte{ + 129, // node type 2 and key length 1 + // missing key data byte + }), + errWrapped: decode.ErrReadKeyData, + errMessage: "cannot decode key: cannot read key data: EOF", + }, + "children bitmap read error": { + reader: bytes.NewBuffer([]byte{ + 129, // node type 2 and key length 1 + 9, // key data + // missing children bitmap 2 bytes + }), + errWrapped: ErrReadChildrenBitmap, + errMessage: "cannot read children bitmap: EOF", + }, + "children decoding error": { + reader: bytes.NewBuffer([]byte{ + 129, // node type 2 and key length 1 + 9, // key data + 0, 4, // children bitmap + // missing children scale encoded data + }), + errWrapped: ErrDecodeChildHash, + errMessage: "cannot decode child hash: at index 10: EOF", + }, + "success node type 2": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + { + 129, // node type 2 and key length 1 + 9, // key data + 0, 4, // children bitmap + }, + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash + }), + ), + branch: &Branch{ + Key: []byte{9}, + Children: [16]node.Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &leaf.Leaf{ + Hash: []byte{1, 2, 3, 4, 5}, + }, + }, + Dirty: true, + }, + }, + "value decoding error for node type 3": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + { + 193, // node type 3 and key length 1 + 9, // key data + }, + {0, 4}, // children bitmap + // missing encoded branch value + }), + ), + errWrapped: ErrDecodeValue, + errMessage: "cannot decode value: EOF", + }, + "success node type 3": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + { + 193, // node type 3 and key length 1 + 9, // key data + }, + {0, 4}, // children bitmap + scaleEncodeBytes(t, 7, 8, 9), // branch value + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash + }), + 
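// For illustration (not part of this diff): how the header bytes used in these
// test cases break down. The node type sits in the top two bits and the partial
// key length in the bottom six bits:
//
//	header := byte(129)                 // 0b10_000001
//	fmt.Println(header>>6, header&0x3f) // 2 1: branch without value, key length 1
//	header = 193                        // 0b11_000001
//	fmt.Println(header>>6, header&0x3f) // 3 1: branch with value, key length 1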
), + branch: &Branch{ + Key: []byte{9}, + Value: []byte{7, 8, 9}, + Children: [16]node.Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &leaf.Leaf{ + Hash: []byte{1, 2, 3, 4, 5}, + }, + }, + Dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + branch, err := Decode(testCase.reader, testCase.header) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.branch, branch) + }) + } +} diff --git a/lib/trie/branch/dirty.go b/lib/trie/branch/dirty.go new file mode 100644 index 0000000000..930c01fa91 --- /dev/null +++ b/lib/trie/branch/dirty.go @@ -0,0 +1,14 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +// IsDirty returns the dirty status of the branch. +func (b *Branch) IsDirty() bool { + return b.Dirty +} + +// SetDirty sets the dirty status to the branch. +func (b *Branch) SetDirty(dirty bool) { + b.Dirty = dirty +} diff --git a/lib/trie/branch/encode.go b/lib/trie/branch/encode.go new file mode 100644 index 0000000000..3a1bfa4e74 --- /dev/null +++ b/lib/trie/branch/encode.go @@ -0,0 +1,242 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +import ( + "bytes" + "fmt" + "hash" + "io" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/lib/trie/pools" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +// ScaleEncodeHash hashes the node (blake2b sum on encoded value) +// and then SCALE encodes it. This is used to encode children +// nodes of branches. +func (b *Branch) ScaleEncodeHash() (encoding []byte, err error) { + // if b == nil { // TODO remove + // panic("Should write 0 to buffer") + // } + + buffer := pools.DigestBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.DigestBuffers.Put(buffer) + + err = b.hash(buffer) + if err != nil { + return nil, fmt.Errorf("cannot hash node: %w", err) + } + + encoding, err = scale.Marshal(buffer.Bytes()) + if err != nil { + return nil, fmt.Errorf("cannot scale encode hashed node: %w", err) + } + + return encoding, nil +} + +func (b *Branch) hash(digestBuffer io.Writer) (err error) { + encodingBuffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + encodingBuffer.Reset() + defer pools.EncodingBuffers.Put(encodingBuffer) + + err = b.Encode(encodingBuffer) + if err != nil { + return fmt.Errorf("cannot encode leaf: %w", err) + } + + // if length of encoded leaf is less than 32 bytes, do not hash + if encodingBuffer.Len() < 32 { + _, err = digestBuffer.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot write encoded branch to buffer: %w", err) + } + return nil + } + + // otherwise, hash encoded node + hasher := pools.Hashers.Get().(hash.Hash) + hasher.Reset() + defer pools.Hashers.Put(hasher) + + // Note: using the sync.Pool's buffer is useful here. + _, err = hasher.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot hash encoded node: %w", err) + } + + _, err = digestBuffer.Write(hasher.Sum(nil)) + if err != nil { + return fmt.Errorf("cannot write hash sum of branch to buffer: %w", err) + } + return nil +} + +// Encode encodes a branch with the encoding specified at the top of this package +// to the buffer given. 
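// For illustration (not part of this diff): the length check in hash() above
// implements the trie's Merkle value rule, where encodings shorter than 32 bytes
// are written as-is and longer encodings are replaced by their blake2b digest.
// A standalone sketch, assuming common.Blake2bHash from lib/common; merkleValue
// is a name made up for this example only:
//
//	func merkleValue(encoding []byte) ([]byte, error) {
//		if len(encoding) < 32 {
//			return encoding, nil // small encodings are inlined, not hashed
//		}
//		digest, err := common.Blake2bHash(encoding) // 32-byte blake2b digest
//		if err != nil {
//			return nil, err
//		}
//		return digest[:], nil
//	}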
+func (b *Branch) Encode(buffer encode.Buffer) (err error) { + if !b.Dirty && b.Encoding != nil { + _, err = buffer.Write(b.Encoding) + if err != nil { + return fmt.Errorf("cannot write stored encoding to buffer: %w", err) + } + return nil + } + + encodedHeader, err := b.Header() + if err != nil { + return fmt.Errorf("cannot encode header: %w", err) + } + + _, err = buffer.Write(encodedHeader) + if err != nil { + return fmt.Errorf("cannot write encoded header to buffer: %w", err) + } + + keyLE := encode.NibblesToKeyLE(b.Key) + _, err = buffer.Write(keyLE) + if err != nil { + return fmt.Errorf("cannot write encoded key to buffer: %w", err) + } + + childrenBitmap := common.Uint16ToBytes(b.ChildrenBitmap()) + _, err = buffer.Write(childrenBitmap) + if err != nil { + return fmt.Errorf("cannot write children bitmap to buffer: %w", err) + } + + if b.Value != nil { + bytes, err := scale.Marshal(b.Value) + if err != nil { + return fmt.Errorf("cannot scale encode value: %w", err) + } + + _, err = buffer.Write(bytes) + if err != nil { + return fmt.Errorf("cannot write encoded value to buffer: %w", err) + } + } + + const parallel = false // TODO + if parallel { + err = encodeChildrenInParallel(b.Children, buffer) + } else { + err = encodeChildrenSequentially(b.Children, buffer) + } + if err != nil { + return fmt.Errorf("cannot encode children of branch: %w", err) + } + + return nil +} + +func encodeChildrenInParallel(children [16]node.Node, buffer io.Writer) (err error) { + type result struct { + index int + buffer *bytes.Buffer + err error + } + + resultsCh := make(chan result) + + for i, child := range children { + go func(index int, child node.Node) { + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + // buffer is put back in the pool after processing its + // data in the select block below. + + err := encodeChild(child, buffer) + + resultsCh <- result{ + index: index, + buffer: buffer, + err: err, + } + }(i, child) + } + + currentIndex := 0 + resultBuffers := make([]*bytes.Buffer, len(children)) + for range children { + result := <-resultsCh + if result.err != nil && err == nil { // only set the first error we get + err = result.err + } + + resultBuffers[result.index] = result.buffer + + // write as many completed buffers to the result buffer. 
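// (Results arrive on resultsCh in completion order, not child order, so each
// buffer is parked in resultBuffers and flushed to the writer only once all
// lower-indexed children have finished: currentIndex only ever advances over a
// contiguous prefix of completed buffers, preserving the child ordering the
// branch encoding requires. Only the first error, from encoding or writing, is
// kept, and each buffer is returned to the pool once flushed.)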
+ for currentIndex < len(children) && + resultBuffers[currentIndex] != nil { + bufferSlice := resultBuffers[currentIndex].Bytes() + if len(bufferSlice) > 0 { + // note buffer.Write copies the byte slice given as argument + _, writeErr := buffer.Write(bufferSlice) + if writeErr != nil && err == nil { + err = fmt.Errorf( + "cannot write encoding of child at index %d: %w", + currentIndex, writeErr) + } + } + + pools.EncodingBuffers.Put(resultBuffers[currentIndex]) + resultBuffers[currentIndex] = nil + + currentIndex++ + } + } + + for _, buffer := range resultBuffers { + if buffer == nil { // already emptied and put back in pool + continue + } + pools.EncodingBuffers.Put(buffer) + } + + return err +} + +func encodeChildrenSequentially(children [16]node.Node, buffer io.Writer) (err error) { + for i, child := range children { + err = encodeChild(child, buffer) + if err != nil { + return fmt.Errorf("cannot encode child at index %d: %w", i, err) + } + } + return nil +} + +func encodeChild(child node.Node, buffer io.Writer) (err error) { + var isNil bool + switch impl := child.(type) { + case *Branch: + isNil = impl == nil + case *leaf.Leaf: + isNil = impl == nil + default: + isNil = child == nil + } + if isNil { + return nil + } + + scaleEncodedChild, err := child.ScaleEncodeHash() + if err != nil { + return fmt.Errorf("failed to hash and scale encode child: %w", err) + } + + _, err = buffer.Write(scaleEncodedChild) + if err != nil { + return fmt.Errorf("failed to write child to buffer: %w", err) + } + + return nil +} diff --git a/lib/trie/branch/encode_test.go b/lib/trie/branch/encode_test.go new file mode 100644 index 0000000000..4b9324f74b --- /dev/null +++ b/lib/trie/branch/encode_test.go @@ -0,0 +1,605 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +import ( + "errors" + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type writeCall struct { + written []byte + n int + err error +} + +var errTest = errors.New("test error") + +//go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/encode Buffer + +func Test_Branch_Encode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + writes []writeCall + parallel bool + wrappedErr error + errMessage string + }{ + "clean branch with encoding": { + branch: &Branch{ + Encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { // stored encoding + written: []byte{1, 2, 3}, + }, + }, + }, + "write error for clean branch with encoding": { + branch: &Branch{ + Encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { // stored encoding + written: []byte{1, 2, 3}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write stored encoding to buffer: test error", + }, + "header encoding error": { + branch: &Branch{ + Key: make([]byte, 63+(1<<16)), + }, + wrappedErr: encode.ErrPartialKeyTooBig, + errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", + }, + "buffer write error for encoded header": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded header to 
buffer: test error", + }, + "buffer write error for encoded key": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded key to buffer: test error", + }, + "buffer write error for children bitmap": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]node.Node{ + nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, + nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write children bitmap to buffer: test error", + }, + "buffer write error for value": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]node.Node{ + nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, + nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded value to buffer: test error", + }, + "buffer write error for children encoded sequentially": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]node.Node{ + nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, + nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + }, + { // children + written: []byte{12, 65, 9, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot encode children of branch: " + + "cannot encode child at index 3: " + + "failed to write child to buffer: test error", + }, + "buffer write error for children encoded in parallel": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]node.Node{ + nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, + nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + }, + { // first children + written: []byte{12, 65, 9, 0}, + err: errTest, + }, + }, + parallel: true, + wrappedErr: errTest, + errMessage: "cannot encode children of branch: " + + "cannot encode child at index 3: " + + "failed to write child to buffer: " + + "test error", + }, + "success with parallel children encoding": { + branch: &Branch{ + Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]node.Node{ + nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, + nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + }, + { // first children + written: []byte{12, 65, 9, 0}, + }, + { // second children + written: []byte{12, 65, 11, 0}, + }, + }, + parallel: true, + }, + "success with sequential children encoding": { + branch: &Branch{ 
+ Key: []byte{1, 2, 3}, + Value: []byte{100}, + Children: [16]node.Node{ + nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, + nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + }, + }, + writes: []writeCall{ + { // header + written: []byte{195}, + }, + { // key LE + written: []byte{1, 35}, + }, + { // children bitmap + written: []byte{136, 0}, + }, + { // value + written: []byte{4, 100}, + }, + { // first children + written: []byte{12, 65, 9, 0}, + }, + { // second children + written: []byte{12, 65, 11, 0}, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockBuffer(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := buffer.EXPECT(). + Write(write.written). + Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := testCase.branch.Encode(buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_encodeChildrenInParallel(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + children [16]node.Node + writes []writeCall + wrappedErr error + errMessage string + }{ + "no children": {}, + "first child not nil": { + children: [16]node.Node{ + &leaf.Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "last child not nil": { + children: [16]node.Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &leaf.Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "first two children not nil": { + children: [16]node.Node{ + &leaf.Leaf{Key: []byte{1}}, + &leaf.Leaf{Key: []byte{2}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + { + written: []byte{12, 65, 2, 0}, + }, + }, + }, + "encoding error": { + children: [16]node.Node{ + nil, nil, nil, nil, + nil, nil, nil, nil, + nil, nil, nil, + &leaf.Leaf{ + Key: []byte{1}, + }, + nil, nil, nil, nil, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoding of child at index 11: " + + "test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := buffer.EXPECT(). + Write(write.written). 
+ Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := encodeChildrenInParallel(testCase.children, buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + +func Test_encodeChildrenSequentially(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + children [16]node.Node + writes []writeCall + wrappedErr error + errMessage string + }{ + "no children": {}, + "first child not nil": { + children: [16]node.Node{ + &leaf.Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "last child not nil": { + children: [16]node.Node{ + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, + &leaf.Leaf{Key: []byte{1}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + }, + }, + "first two children not nil": { + children: [16]node.Node{ + &leaf.Leaf{Key: []byte{1}}, + &leaf.Leaf{Key: []byte{2}}, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + }, + { + written: []byte{12, 65, 2, 0}, + }, + }, + }, + "encoding error": { + children: [16]node.Node{ + nil, nil, nil, nil, + nil, nil, nil, nil, + nil, nil, nil, + &leaf.Leaf{ + Key: []byte{1}, + }, + nil, nil, nil, nil, + }, + writes: []writeCall{ + { + written: []byte{12, 65, 1, 0}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot encode child at index 11: " + + "failed to write child to buffer: test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := buffer.EXPECT(). + Write(write.written). 
+ Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := encodeChildrenSequentially(testCase.children, buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + +//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer + +func Test_encodeChild(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + child node.Node + writeCall bool + write writeCall + wrappedErr error + errMessage string + }{ + "nil node": {}, + "nil leaf": { + child: (*leaf.Leaf)(nil), + }, + "nil branch": { + child: (*Branch)(nil), + }, + "empty leaf child": { + child: &leaf.Leaf{}, + writeCall: true, + write: writeCall{ + written: []byte{8, 64, 0}, + }, + }, + "empty branch child": { + child: &Branch{}, + writeCall: true, + write: writeCall{ + written: []byte{12, 128, 0, 0}, + }, + }, + "buffer write error": { + child: &Branch{}, + writeCall: true, + write: writeCall{ + written: []byte{12, 128, 0, 0}, + err: errTest, + }, + wrappedErr: errTest, + errMessage: "failed to write child to buffer: test error", + }, + "leaf child": { + child: &leaf.Leaf{ + Key: []byte{1}, + Value: []byte{2}, + }, + writeCall: true, + write: writeCall{ + written: []byte{16, 65, 1, 4, 2}, + }, + }, + "branch child": { + child: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + Children: [16]node.Node{ + nil, nil, &leaf.Leaf{ + Key: []byte{5}, + Value: []byte{6}, + }, + }, + }, + writeCall: true, + write: writeCall{ + written: []byte{44, 193, 1, 4, 0, 4, 2, 16, 65, 5, 4, 6}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockWriter(ctrl) + + if testCase.writeCall { + buffer.EXPECT(). + Write(testCase.write.written). + Return(testCase.write.n, testCase.write.err) + } + + err := encodeChild(testCase.child, buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/lib/trie/branch/generation.go b/lib/trie/branch/generation.go new file mode 100644 index 0000000000..a5d8f4e510 --- /dev/null +++ b/lib/trie/branch/generation.go @@ -0,0 +1,14 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +// SetGeneration sets the generation given to the branch. +func (b *Branch) SetGeneration(generation uint64) { + b.Generation = generation +} + +// GetGeneration returns the generation of the branch. +func (b *Branch) GetGeneration() uint64 { + return b.Generation +} diff --git a/lib/trie/branch/hash.go b/lib/trie/branch/hash.go new file mode 100644 index 0000000000..d826dc7ba8 --- /dev/null +++ b/lib/trie/branch/hash.go @@ -0,0 +1,68 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +import ( + "bytes" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/pools" +) + +// SetEncodingAndHash sets the encoding and hash slices +// given to the branch. Note it does not copy them, so beware. +func (b *Branch) SetEncodingAndHash(enc, hash []byte) { + b.Encoding = enc + b.Hash = hash +} + +// GetHash returns the hash of the branch. 
+// Note it does not copy it, so modifying +// the returned hash will modify the hash +// of the branch. +func (b *Branch) GetHash() []byte { + return b.Hash +} + +// EncodeAndHash returns the encoding of the branch and +// the blake2b hash digest of the encoding of the branch. +// If the encoding is less than 32 bytes, the hash returned +// is the encoding and not the hash of the encoding. +func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { + if !b.Dirty && b.Encoding != nil && b.Hash != nil { + return b.Encoding, b.Hash, nil + } + + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.EncodingBuffers.Put(buffer) + + err = b.Encode(buffer) + if err != nil { + return nil, nil, err + } + + bufferBytes := buffer.Bytes() + + b.Encoding = make([]byte, len(bufferBytes)) + copy(b.Encoding, bufferBytes) + encoding = b.Encoding // no need to copy + + if buffer.Len() < 32 { + b.Hash = make([]byte, len(bufferBytes)) + copy(b.Hash, bufferBytes) + hash = b.Hash // no need to copy + return encoding, hash, nil + } + + // Note: using the sync.Pool's buffer is useful here. + hashArray, err := common.Blake2bHash(buffer.Bytes()) + if err != nil { + return nil, nil, err + } + b.Hash = hashArray[:] + hash = b.Hash // no need to copy + + return encoding, hash, nil +} diff --git a/lib/trie/branch/header.go b/lib/trie/branch/header.go new file mode 100644 index 0000000000..b990d00529 --- /dev/null +++ b/lib/trie/branch/header.go @@ -0,0 +1,32 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +import "github.com/ChainSafe/gossamer/lib/trie/encode" + +// Header creates the encoded header for the branch. +func (b *Branch) Header() (encoding []byte, err error) { + var header byte + if b.Value == nil { + header = 2 << 6 + } else { + header = 3 << 6 + } + + var encodedPublicKeyLength []byte + if len(b.Key) >= 63 { + header = header | 0x3f + encodedPublicKeyLength, err = encode.ExtraPartialKeyLength(len(b.Key)) + if err != nil { + return nil, err + } + } else { + header = header | byte(len(b.Key)) + } + + encoding = make([]byte, 0, len(encodedPublicKeyLength)+1) + encoding = append(encoding, header) + encoding = append(encoding, encodedPublicKeyLength...) 
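// For illustration (not part of this diff): worked examples of the header bytes
// produced here, matching the test cases in header_test.go below. The top two
// bits are 10 (branch) or 11 (branch with value); the bottom six bits hold the
// key length, with 0x3f acting as an escape to the extra length byte(s):
//
//	len(Key) == 30, no value:   0x80 | 30        -> 0x9e
//	len(Key) == 63, no value:   0x80 | 0x3f, 0   -> 0xbf 0x00  (63 - 63 = 0)
//	len(Key) == 64, no value:   0x80 | 0x3f, 1   -> 0xbf 0x01  (64 - 63 = 1)
//	len(Key) == 0, with value:  0xc0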
+ return encoding, nil +} diff --git a/lib/trie/branch/header_test.go b/lib/trie/branch/header_test.go new file mode 100644 index 0000000000..ad251b468c --- /dev/null +++ b/lib/trie/branch/header_test.go @@ -0,0 +1,80 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +import ( + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/stretchr/testify/assert" +) + +func Test_Branch_Header(t *testing.T) { + testCases := map[string]struct { + branch *Branch + encoding []byte + wrappedErr error + errMessage string + }{ + "no key": { + branch: &Branch{}, + encoding: []byte{0x80}, + }, + "with value": { + branch: &Branch{ + Value: []byte{}, + }, + encoding: []byte{0xc0}, + }, + "key of length 30": { + branch: &Branch{ + Key: make([]byte, 30), + }, + encoding: []byte{0x9e}, + }, + "key of length 62": { + branch: &Branch{ + Key: make([]byte, 62), + }, + encoding: []byte{0xbe}, + }, + "key of length 63": { + branch: &Branch{ + Key: make([]byte, 63), + }, + encoding: []byte{0xbf, 0x0}, + }, + "key of length 64": { + branch: &Branch{ + Key: make([]byte, 64), + }, + encoding: []byte{0xbf, 0x1}, + }, + "key too big": { + branch: &Branch{ + Key: make([]byte, 65535+63), + }, + wrappedErr: encode.ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, err := testCase.branch.Header() + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.encoding, encoding) + }) + } +} diff --git a/lib/trie/branch/key.go b/lib/trie/branch/key.go new file mode 100644 index 0000000000..aa88e8a0c3 --- /dev/null +++ b/lib/trie/branch/key.go @@ -0,0 +1,11 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package branch + +// SetKey sets the key to the branch. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the branch. +func (b *Branch) SetKey(key []byte) { + b.Key = key +} diff --git a/lib/trie/branch/writer_mock_test.go b/lib/trie/branch/writer_mock_test.go new file mode 100644 index 0000000000..609c8f248d --- /dev/null +++ b/lib/trie/branch/writer_mock_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: io (interfaces: Writer) + +// Package branch is a generated GoMock package. +package branch + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockWriter is a mock of Writer interface. +type MockWriter struct { + ctrl *gomock.Controller + recorder *MockWriterMockRecorder +} + +// MockWriterMockRecorder is the mock recorder for MockWriter. +type MockWriterMockRecorder struct { + mock *MockWriter +} + +// NewMockWriter creates a new mock instance. +func NewMockWriter(ctrl *gomock.Controller) *MockWriter { + mock := &MockWriter{ctrl: ctrl} + mock.recorder = &MockWriterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWriter) EXPECT() *MockWriterMockRecorder { + return m.recorder +} + +// Write mocks base method. 
+func (m *MockWriter) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), arg0) +} diff --git a/lib/trie/bytesBuffer_mock_test.go b/lib/trie/bytesBuffer_mock_test.go deleted file mode 100644 index c59f7dd4a9..0000000000 --- a/lib/trie/bytesBuffer_mock_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: hash.go - -// Package trie is a generated GoMock package. -package trie - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockbytesBuffer is a mock of bytesBuffer interface. -type MockbytesBuffer struct { - ctrl *gomock.Controller - recorder *MockbytesBufferMockRecorder -} - -// MockbytesBufferMockRecorder is the mock recorder for MockbytesBuffer. -type MockbytesBufferMockRecorder struct { - mock *MockbytesBuffer -} - -// NewMockbytesBuffer creates a new mock instance. -func NewMockbytesBuffer(ctrl *gomock.Controller) *MockbytesBuffer { - mock := &MockbytesBuffer{ctrl: ctrl} - mock.recorder = &MockbytesBufferMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockbytesBuffer) EXPECT() *MockbytesBufferMockRecorder { - return m.recorder -} - -// Bytes mocks base method. -func (m *MockbytesBuffer) Bytes() []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Bytes") - ret0, _ := ret[0].([]byte) - return ret0 -} - -// Bytes indicates an expected call of Bytes. -func (mr *MockbytesBufferMockRecorder) Bytes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockbytesBuffer)(nil).Bytes)) -} - -// Len mocks base method. -func (m *MockbytesBuffer) Len() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Len") - ret0, _ := ret[0].(int) - return ret0 -} - -// Len indicates an expected call of Len. -func (mr *MockbytesBufferMockRecorder) Len() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockbytesBuffer)(nil).Len)) -} - -// Write mocks base method. -func (m *MockbytesBuffer) Write(p []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", p) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. 
-func (mr *MockbytesBufferMockRecorder) Write(p interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockbytesBuffer)(nil).Write), p) -} diff --git a/lib/trie/codec.go b/lib/trie/codec.go deleted file mode 100644 index 33bad34007..0000000000 --- a/lib/trie/codec.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -// keyToNibbles turns bytes into nibbles -// does not rearrange the nibbles; assumes they are already ordered in LE -func keyToNibbles(in []byte) []byte { - if len(in) == 0 { - return []byte{} - } else if len(in) == 1 && in[0] == 0 { - return []byte{0, 0} - } - - l := len(in) * 2 - res := make([]byte, l) - for i, b := range in { - res[2*i] = b / 16 - res[2*i+1] = b % 16 - } - - return res -} - -// nibblesToKey turns a slice of nibbles w/ length k into a big endian byte array -// if the length of the input is odd, the result is [ in[1] in[0] | ... | 0000 in[k-1] ] -// otherwise, res = [ in[1] in[0] | ... | in[k-1] in[k-2] ] -func nibblesToKey(in []byte) (res []byte) { - if len(in)%2 == 0 { - res = make([]byte, len(in)/2) - for i := 0; i < len(in); i += 2 { - res[i/2] = (in[i] & 0xf) | (in[i+1] << 4 & 0xf0) - } - } else { - res = make([]byte, len(in)/2+1) - for i := 0; i < len(in); i += 2 { - if i < len(in)-1 { - res[i/2] = (in[i] & 0xf) | (in[i+1] << 4 & 0xf0) - } else { - res[i/2] = (in[i] & 0xf) - } - } - } - - return res -} - -// nibblesToKey turns a slice of nibbles w/ length k into a little endian byte array -// assumes nibbles are already LE, does not rearrange nibbles -// if the length of the input is odd, the result is [ 0000 in[0] | in[1] in[2] | ... | in[k-2] in[k-1] ] -// otherwise, res = [ in[0] in[1] | ... | in[k-2] in[k-1] ] -func nibblesToKeyLE(in []byte) (res []byte) { - if len(in)%2 == 0 { - res = make([]byte, len(in)/2) - for i := 0; i < len(in); i += 2 { - res[i/2] = (in[i] << 4 & 0xf0) | (in[i+1] & 0xf) - } - } else { - res = make([]byte, len(in)/2+1) - res[0] = in[0] - for i := 2; i < len(in); i += 2 { - res[i/2] = (in[i-1] << 4 & 0xf0) | (in[i] & 0xf) - } - } - - return res -} diff --git a/lib/trie/codec_test.go b/lib/trie/codec_test.go deleted file mode 100644 index 108a5acfa7..0000000000 --- a/lib/trie/codec_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "fmt" - "testing" -) - -func TestKeyToNibbles(t *testing.T) { - tests := []struct { - input []byte - expected []byte - }{ - {[]byte{0x0}, []byte{0, 0}}, - {[]byte{0xFF}, []byte{0xF, 0xF}}, - {[]byte{0x3a, 0x05}, []byte{0x3, 0xa, 0x0, 0x5}}, - {[]byte{0xAA, 0xFF, 0x01}, []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}}, - {[]byte{0xAA, 0xFF, 0x01, 0xc2}, []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}}, - {[]byte{0xAA, 0xFF, 0x01, 0xc0}, []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x0}}, - } - - for _, test := range tests { - test := test - t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { - res := keyToNibbles(test.input) - if !bytes.Equal(test.expected, res) { - t.Errorf("Output doesn't match expected. 
got=%v expected=%v\n", res, test.expected) - } - }) - } -} - -func TestNibblesToKey(t *testing.T) { - tests := []struct { - input []byte - expected []byte - }{ - {[]byte{0xF, 0xF}, []byte{0xFF}}, - {[]byte{0x3, 0xa, 0x0, 0x5}, []byte{0xa3, 0x50}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, []byte{0xaa, 0xff, 0x10}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, []byte{0xaa, 0xff, 0x10, 0x2c}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, []byte{0xaa, 0xff, 0x10, 0x0c}}, - } - - for _, test := range tests { - test := test - t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { - res := nibblesToKey(test.input) - if !bytes.Equal(test.expected, res) { - t.Errorf("Output doesn't match expected. got=%x expected=%x\n", res, test.expected) - } - }) - } -} - -func TestNibblesToKeyLE(t *testing.T) { - tests := []struct { - input []byte - expected []byte - }{ - {[]byte{0xF, 0xF}, []byte{0xFF}}, - {[]byte{0x3, 0xa, 0x0, 0x5}, []byte{0x3a, 0x05}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, []byte{0xaa, 0xff, 0x01}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, []byte{0xaa, 0xff, 0x01, 0xc2}}, - {[]byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, []byte{0xa, 0xaf, 0xf0, 0x1c}}, - } - - for _, test := range tests { - test := test - t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { - res := nibblesToKeyLE(test.input) - if !bytes.Equal(test.expected, res) { - t.Errorf("Output doesn't match expected. got=%x expected=%x\n", res, test.expected) - } - }) - } -} diff --git a/lib/trie/database.go b/lib/trie/database.go index f53c0c139c..db337d0591 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -9,6 +9,10 @@ import ( "fmt" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/chaindb" ) @@ -31,7 +35,7 @@ func (t *Trie) Store(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) store(db chaindb.Batch, curr Node) error { +func (t *Trie) store(db chaindb.Batch, curr node.Node) error { if curr == nil { return nil } @@ -46,8 +50,8 @@ func (t *Trie) store(db chaindb.Batch, curr Node) error { return err } - if c, ok := curr.(*Branch); ok { - for _, child := range c.children { + if c, ok := curr.(*branch.Branch); ok { + for _, child := range c.Children { if child == nil { continue } @@ -72,12 +76,12 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return ErrEmptyProof } - mappedNodes := make(map[string]Node, len(proof)) + mappedNodes := make(map[string]node.Node, len(proof)) // map all the proofs hash -> decoded node // and takes the loop to indentify the root node for _, rawNode := range proof { - decNode, err := decodeBytes(rawNode) + decNode, err := decodeNode(bytes.NewBuffer(rawNode)) if err != nil { return err } @@ -103,13 +107,13 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // loadProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root -func (t *Trie) loadProof(proof map[string]Node, curr Node) { - c, ok := curr.(*Branch) +func (t *Trie) loadProof(proof map[string]node.Node, curr node.Node) { + c, ok := curr.(*branch.Branch) if !ok { return } - for i, child := range c.children { + for i, child := range c.Children { if child == nil { continue } @@ -119,7 +123,7 @@ func (t *Trie) loadProof(proof map[string]Node, curr Node) { continue } - c.children[i] = 
proofNode + c.Children[i] = proofNode t.loadProof(proof, proofNode) } } @@ -137,7 +141,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return fmt.Errorf("failed to find root key=%s: %w", root, err) } - t.root, err = decodeBytes(enc) + t.root, err = decodeNode(bytes.NewBuffer(enc)) if err != nil { return err } @@ -148,9 +152,9 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return t.load(db, t.root) } -func (t *Trie) load(db chaindb.Database, curr Node) error { - if c, ok := curr.(*Branch); ok { - for i, child := range c.children { +func (t *Trie) load(db chaindb.Database, curr node.Node) error { + if c, ok := curr.(*branch.Branch); ok { + for i, child := range c.Children { if child == nil { continue } @@ -158,10 +162,10 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { hash := child.GetHash() enc, err := db.Get(hash) if err != nil { - return fmt.Errorf("failed to find node key=%x index=%d: %w", child.(*Leaf).hash, i, err) + return fmt.Errorf("failed to find node key=%x index=%d: %w", child.(*leaf.Leaf).Hash, i, err) } - child, err = decodeBytes(enc) + child, err = decodeNode(bytes.NewBuffer(enc)) if err != nil { return err } @@ -169,7 +173,7 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { child.SetDirty(false) child.SetEncodingAndHash(enc, hash) - c.children[i] = child + c.Children[i] = child err = t.load(db, child) if err != nil { return err @@ -181,9 +185,9 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { } // GetNodeHashes return hash of each key of the trie. -func (t *Trie) GetNodeHashes(curr Node, keys map[common.Hash]struct{}) error { - if c, ok := curr.(*Branch); ok { - for _, child := range c.children { +func (t *Trie) GetNodeHashes(curr node.Node, keys map[common.Hash]struct{}) error { + if c, ok := curr.(*branch.Branch); ok { + for _, child := range c.Children { if child == nil { continue } @@ -234,14 +238,14 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return nil, nil } - k := keyToNibbles(key) + k := decode.KeyLEToNibbles(key) enc, err := db.Get(root[:]) if err != nil { return nil, fmt.Errorf("failed to find root key=%s: %w", root, err) } - rootNode, err := decodeBytes(enc) + rootNode, err := decodeNode(bytes.NewBuffer(enc)) if err != nil { return nil, err } @@ -249,34 +253,34 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return getFromDB(db, rootNode, k) } -func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { +func getFromDB(db chaindb.Database, parent node.Node, key []byte) ([]byte, error) { var value []byte switch p := parent.(type) { - case *Branch: - length := lenCommonPrefix(p.key, key) + case *branch.Branch: + length := lenCommonPrefix(p.Key, key) // found the value at this node - if bytes.Equal(p.key, key) || len(key) == 0 { - return p.value, nil + if bytes.Equal(p.Key, key) || len(key) == 0 { + return p.Value, nil } // did not find value - if bytes.Equal(p.key[:length], key) && len(key) < len(p.key) { + if bytes.Equal(p.Key[:length], key) && len(key) < len(p.Key) { return nil, nil } - if p.children[key[length]] == nil { + if p.Children[key[length]] == nil { return nil, nil } // load child with potential value - enc, err := db.Get(p.children[key[length]].(*Leaf).hash) + enc, err := db.Get(p.Children[key[length]].(*leaf.Leaf).Hash) if err != nil { return nil, fmt.Errorf("failed to find node in database: %w", err) } - child, err := decodeBytes(enc) + child, err := 
decodeNode(bytes.NewBuffer(enc)) if err != nil { return nil, err } @@ -285,9 +289,9 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { if err != nil { return nil, err } - case *Leaf: - if bytes.Equal(p.key, key) { - return p.value, nil + case *leaf.Leaf: + if bytes.Equal(p.Key, key) { + return p.Value, nil } case nil: return nil, nil @@ -308,7 +312,7 @@ func (t *Trie) WriteDirty(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error { +func (t *Trie) writeDirty(db chaindb.Batch, curr node.Node) error { if curr == nil || !curr.IsDirty() { return nil } @@ -333,8 +337,8 @@ func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error { return err } - if c, ok := curr.(*Branch); ok { - for _, child := range c.children { + if c, ok := curr.(*branch.Branch); ok { + for _, child := range c.Children { if child == nil { continue } @@ -356,7 +360,7 @@ func (t *Trie) GetInsertedNodeHashes() ([]common.Hash, error) { return t.getInsertedNodeHashes(t.root) } -func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) { +func (t *Trie) getInsertedNodeHashes(curr node.Node) ([]common.Hash, error) { var nodeHashes []common.Hash if curr == nil || !curr.IsDirty() { return nil, nil @@ -379,8 +383,8 @@ func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) { nodeHash := common.BytesToHash(hash) nodeHashes = append(nodeHashes, nodeHash) - if c, ok := curr.(*Branch); ok { - for _, child := range c.children { + if c, ok := curr.(*branch.Branch); ok { + for _, child := range c.Children { if child == nil { continue } diff --git a/lib/trie/decode.go b/lib/trie/decode.go new file mode 100644 index 0000000000..452c60cdc7 --- /dev/null +++ b/lib/trie/decode.go @@ -0,0 +1,45 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package trie + +import ( + "errors" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" +) + +var ( + ErrReadHeaderByte = errors.New("cannot read header byte") + ErrUnknownNodeType = errors.New("unknown node type") +) + +func decodeNode(reader io.Reader) (n node.Node, err error) { + header, err := decode.ReadNextByte(reader) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) + } + + nodeType := header >> 6 + switch nodeType { + case node.LeafType: + n, err = leaf.Decode(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode leaf: %w", err) + } + return n, nil + case node.BranchType, node.BranchWithValueType: + n, err = branch.Decode(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode branch: %w", err) + } + return n, nil + default: + return nil, fmt.Errorf("%w: %d", ErrUnknownNodeType, nodeType) + } +} diff --git a/lib/trie/decode/byte.go b/lib/trie/decode/byte.go new file mode 100644 index 0000000000..4560fcbc90 --- /dev/null +++ b/lib/trie/decode/byte.go @@ -0,0 +1,16 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package decode + +import "io" + +// ReadNextByte reads the next byte from the reader. 
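// For illustration (not part of this diff): decodeNode in lib/trie/decode.go
// earlier in this patch dispatches on the top two bits of the header byte
// (1 leaf, 2 branch, 3 branch with value). A sketch of decoding a leaf from
// within the trie package, mirroring the "leaf success" case in decode_test.go:
//
//	n, err := decodeNode(bytes.NewBuffer([]byte{
//		65,          // 0b01_000001: leaf, key length 1
//		9,           // key data (one nibble)
//		12, 1, 2, 3, // SCALE encoding of the value {1, 2, 3}
//	}))
//	// err == nil; n is &leaf.Leaf{Key: []byte{9}, Value: []byte{1, 2, 3}, Dirty: true}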
+func ReadNextByte(reader io.Reader) (b byte, err error) { + buffer := make([]byte, 1) + _, err = reader.Read(buffer) + if err != nil { + return 0, err + } + return buffer[0], nil +} diff --git a/lib/trie/decode/byte_test.go b/lib/trie/decode/byte_test.go new file mode 100644 index 0000000000..4a8115d287 --- /dev/null +++ b/lib/trie/decode/byte_test.go @@ -0,0 +1,52 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package decode + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ReadNextByte(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + b byte + errWrapped error + errMessage string + }{ + "empty buffer": { + reader: bytes.NewBuffer(nil), + errWrapped: io.EOF, + errMessage: "EOF", + }, + "single byte buffer": { + reader: bytes.NewBuffer([]byte{1}), + b: 1, + }, + "two bytes buffer": { + reader: bytes.NewBuffer([]byte{1, 2}), + b: 1, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + b, err := ReadNextByte(testCase.reader) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.b, b) + }) + } +} diff --git a/lib/trie/decode/key.go b/lib/trie/decode/key.go new file mode 100644 index 0000000000..7c014bc944 --- /dev/null +++ b/lib/trie/decode/key.go @@ -0,0 +1,77 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package decode + +import ( + "errors" + "fmt" + "io" +) + +const maxPartialKeySize = ^uint16(0) + +var ( + ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16") + ErrReadKeyLength = errors.New("cannot read key length") + ErrReadKeyData = errors.New("cannot read key data") +) + +// Key decodes a key from a reader. +func Key(reader io.Reader, keyLength byte) (b []byte, err error) { + publicKeyLength := int(keyLength) + + if keyLength == 0x3f { + // partial key longer than 63, read next bytes for rest of pk len + for { + nextKeyLen, err := ReadNextByte(reader) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadKeyLength, err) + } + publicKeyLength += int(nextKeyLen) + + if nextKeyLen < 0xff { + break + } + + if publicKeyLength >= int(maxPartialKeySize) { + return nil, fmt.Errorf("%w: %d", + ErrPartialKeyTooBig, publicKeyLength) + } + } + } + + if publicKeyLength == 0 { + return []byte{}, nil + } + + key := make([]byte, publicKeyLength/2+publicKeyLength%2) + n, err := reader.Read(key) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadKeyData, err) + } else if n != len(key) { + return nil, fmt.Errorf("%w: read %d bytes instead of %d", + ErrReadKeyData, n, len(key)) + } + + return KeyLEToNibbles(key)[publicKeyLength%2:], nil +} + +// KeyLEToNibbles converts a Little Endian byte slice into nibbles. +// It assumes bytes are already in Little Endian and does not rearrange nibbles. 
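// For illustration (not part of this diff): KeyLEToNibbles below undoes
// encode.NibblesToKeyLE for even-length nibble slices, as exercised by the
// tests of both packages:
//
//	decode.KeyLEToNibbles([]byte{0x3a, 0x05})         // []byte{0x3, 0xa, 0x0, 0x5}
//	encode.NibblesToKeyLE([]byte{0x3, 0xa, 0x0, 0x5}) // []byte{0x3a, 0x05}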
+func KeyLEToNibbles(in []byte) (nibbles []byte) { + if len(in) == 0 { + return []byte{} + } else if len(in) == 1 && in[0] == 0 { + return []byte{0, 0} + } + + l := len(in) * 2 + nibbles = make([]byte, l) + for i, b := range in { + nibbles[2*i] = b / 16 + nibbles[2*i+1] = b % 16 + } + + return nibbles +} diff --git a/lib/trie/decode/key_test.go b/lib/trie/decode/key_test.go new file mode 100644 index 0000000000..c25749d7e8 --- /dev/null +++ b/lib/trie/decode/key_test.go @@ -0,0 +1,138 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package decode + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func repeatBytes(n int, b byte) (slice []byte) { + slice = make([]byte, n) + for i := range slice { + slice[i] = b + } + return slice +} + +func Test_Key(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + keyLength byte + b []byte + errWrapped error + errMessage string + }{ + "zero key length": { + b: []byte{}, + }, + "short key length": { + reader: bytes.NewBuffer([]byte{1, 2, 3}), + keyLength: 5, + b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, + }, + "key read error": { + reader: bytes.NewBuffer(nil), + keyLength: 5, + errWrapped: ErrReadKeyData, + errMessage: "cannot read key data: EOF", + }, + "long key length": { + reader: bytes.NewBuffer( + append( + []byte{ + 6, // key length + }, + repeatBytes(64, 7)..., // key data + )), + keyLength: 0x3f, + b: []byte{ + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, + }, + "key length read error": { + reader: bytes.NewBuffer(nil), + keyLength: 0x3f, + errWrapped: ErrReadKeyLength, + errMessage: "cannot read key length: EOF", + }, + "key length too big": { + reader: bytes.NewBuffer(repeatBytes(257, 0xff)), + keyLength: 0x3f, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65598", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + b, err := Key(testCase.reader, testCase.keyLength) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.b, b) + }) + } +} + +func Test_KeyToNibbles(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in []byte + nibbles []byte + }{ + "nil input": { + nibbles: []byte{}, + }, + "empty input": { + in: []byte{}, + nibbles: []byte{}, + }, + "0x0": { + in: []byte{0x0}, + nibbles: []byte{0, 0}}, + "0xFF": { + in: []byte{0xFF}, + nibbles: []byte{0xF, 0xF}}, + "0x3a 0x05": { + in: []byte{0x3a, 0x05}, + nibbles: []byte{0x3, 0xa, 0x0, 0x5}}, + "0xAA 0xFF 0x01": { + in: []byte{0xAA, 0xFF, 0x01}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}}, + "0xAA 0xFF 0x01 0xc2": { + in: []byte{0xAA, 0xFF, 0x01, 0xc2}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}}, + "0xAA 0xFF 0x01 0xc0": { + in: []byte{0xAA, 0xFF, 0x01, 0xc0}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x0}}, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nibbles := KeyLEToNibbles(testCase.in) + + assert.Equal(t, 
testCase.nibbles, nibbles) + }) + } +} diff --git a/lib/trie/decode_test.go b/lib/trie/decode_test.go new file mode 100644 index 0000000000..32002d4f24 --- /dev/null +++ b/lib/trie/decode_test.go @@ -0,0 +1,106 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package trie + +import ( + "bytes" + "io" + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { + encoded, err := scale.Marshal(b) + require.NoError(t, err) + return encoded +} + +func Test_decodeNode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + n node.Node + errWrapped error + errMessage string + }{ + "no data": { + reader: bytes.NewBuffer(nil), + errWrapped: ErrReadHeaderByte, + errMessage: "cannot read header byte: EOF", + }, + "unknown node type": { + reader: bytes.NewBuffer([]byte{0}), + errWrapped: ErrUnknownNodeType, + errMessage: "unknown node type: 0", + }, + "leaf decoding error": { + reader: bytes.NewBuffer([]byte{ + 65, // node type 1 and key length 1 + // missing key data byte + }), + errWrapped: decode.ErrReadKeyData, + errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", + }, + "leaf success": { + reader: bytes.NewBuffer( + append( + []byte{ + 65, // node type 1 and key length 1 + 9, // key data + }, + scaleEncodeBytes(t, 1, 2, 3)..., + ), + ), + n: &leaf.Leaf{ + Key: []byte{9}, + Value: []byte{1, 2, 3}, + Dirty: true, + }, + }, + "branch decoding error": { + reader: bytes.NewBuffer([]byte{ + 129, // node type 2 and key length 1 + // missing key data byte + }), + errWrapped: decode.ErrReadKeyData, + errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", + }, + "branch success": { + reader: bytes.NewBuffer( + []byte{ + 129, // node type 2 and key length 1 + 9, // key data + 0, 0, // no children bitmap + }, + ), + n: &branch.Branch{ + Key: []byte{9}, + Dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + n, err := decodeNode(testCase.reader) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.n, n) + }) + } +} diff --git a/lib/trie/encode/buffer.go b/lib/trie/encode/buffer.go new file mode 100644 index 0000000000..748f30ed97 --- /dev/null +++ b/lib/trie/encode/buffer.go @@ -0,0 +1,16 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package encode + +import "io" + +// Buffer is an interface with some methods of *bytes.Buffer. +type Buffer interface { + Writer + Len() int + Bytes() []byte +} + +// Writer is the io.Writer interface +type Writer io.Writer diff --git a/lib/trie/encode/doc.go b/lib/trie/encode/doc.go new file mode 100644 index 0000000000..e2fc9fd64d --- /dev/null +++ b/lib/trie/encode/doc.go @@ -0,0 +1,28 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package encode + +//nolint:lll +// Modified Merkle-Patricia Trie +// See https://github.com/w3f/polkadot-spec/blob/master/runtime-environment-spec/polkadot_re_spec.pdf for the full specification. 
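// For illustration (not part of this diff): a concrete instance of the leaf
// encoding described in the definitions below, matching the decode tests in
// this patch. A leaf with partial key nibbles {9} and value {1, 2, 3} encodes to:
//
//	0x41                 NodeHeader: 01 (leaf) | key length 1
//	0x09                 Partial Key (the single nibble 9)
//	0x0c 0x01 0x02 0x03  SCALE encoding of the value {1, 2, 3}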
+// +// Note that for the following definitions, `|` denotes concatenation +// +// Branch encoding: +// NodeHeader | Extra partial key length | Partial Key | Value +// `NodeHeader` is a byte such that: +// most significant two bits of `NodeHeader`: 10 if branch w/o value, 11 if branch w/ value +// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) +// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length +// `Partial Key` is the branch's key +// `Value` is: Children Bitmap | SCALE Branch node Value | Hash(Enc(Child[i_1])) | Hash(Enc(Child[i_2])) | ... | Hash(Enc(Child[i_n])) +// +// Leaf encoding: +// NodeHeader | Extra partial key length | Partial Key | Value +// `NodeHeader` is a byte such that: +// most significant two bits of `NodeHeader`: 01 +// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) +// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length +// `Partial Key` is the leaf's key +// `Value` is the leaf's SCALE encoded value diff --git a/lib/trie/encode/key.go b/lib/trie/encode/key.go new file mode 100644 index 0000000000..8a8e97ebd1 --- /dev/null +++ b/lib/trie/encode/key.go @@ -0,0 +1,84 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package encode + +import ( + "errors" + "fmt" +) + +const maxPartialKeySize = ^uint16(0) + +var ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16") + +// ExtraPartialKeyLength encodes the public key length. +func ExtraPartialKeyLength(publicKeyLength int) (encoding []byte, err error) { + publicKeyLength -= 63 + + if publicKeyLength >= int(maxPartialKeySize) { + return nil, fmt.Errorf("%w: %d", + ErrPartialKeyTooBig, publicKeyLength) + } + + for i := uint16(0); i < maxPartialKeySize; i++ { + if publicKeyLength < 255 { + encoding = append(encoding, byte(publicKeyLength)) + break + } + encoding = append(encoding, byte(255)) + publicKeyLength -= 255 + } + + return encoding, nil +} + +// NibblesToKey converts a slice of nibbles with length k into a +// Big Endian byte slice. +// It assumes nibbles are already in Little Endian and does not rearrange nibbles. +// If the length of the input is odd, the result is +// [ in[1] in[0] | ... | 0000 in[k-1] ] +// Otherwise, the result is +// [ in[1] in[0] | ... | in[k-1] in[k-2] ] +func NibblesToKey(nibbles []byte) (key []byte) { + if len(nibbles)%2 == 0 { + key = make([]byte, len(nibbles)/2) + for i := 0; i < len(nibbles); i += 2 { + key[i/2] = (nibbles[i] & 0xf) | (nibbles[i+1] << 4 & 0xf0) + } + } else { + key = make([]byte, len(nibbles)/2+1) + for i := 0; i < len(nibbles); i += 2 { + key[i/2] = nibbles[i] & 0xf + if i < len(nibbles)-1 { + key[i/2] |= (nibbles[i+1] << 4 & 0xf0) + } + } + } + + return key +} + +// NibblesToKeyLE converts a slice of nibbles with length k into a +// Little Endian byte slice. +// It assumes nibbles are already in Little Endian and does not rearrange nibbles. +// If the length of the input is odd, the result is +// [ 0000 in[0] | in[1] in[2] | ... | in[k-2] in[k-1] ] +// Otherwise, the result is +// [ in[0] in[1] | ... 
| in[k-2] in[k-1] ] +func NibblesToKeyLE(nibbles []byte) (keyLE []byte) { + if len(nibbles)%2 == 0 { + keyLE = make([]byte, len(nibbles)/2) + for i := 0; i < len(nibbles); i += 2 { + keyLE[i/2] = (nibbles[i] << 4 & 0xf0) | (nibbles[i+1] & 0xf) + } + } else { + keyLE = make([]byte, len(nibbles)/2+1) + keyLE[0] = nibbles[0] + for i := 2; i < len(nibbles); i += 2 { + keyLE[i/2] = (nibbles[i-1] << 4 & 0xf0) | (nibbles[i] & 0xf) + } + } + + return keyLE +} diff --git a/lib/trie/encode/key_test.go b/lib/trie/encode/key_test.go new file mode 100644 index 0000000000..acdd5d0158 --- /dev/null +++ b/lib/trie/encode/key_test.go @@ -0,0 +1,170 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package encode + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ExtraPartialKeyLength(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + publicKeyLength int + encoding []byte + err error + }{ + "length equal to maximum": { + publicKeyLength: int(maxPartialKeySize) + 63, + err: ErrPartialKeyTooBig, + }, + "zero length": { + encoding: []byte{0xc1}, + }, + "one length": { + publicKeyLength: 1, + encoding: []byte{0xc2}, + }, + "length at maximum allowed": { + publicKeyLength: int(maxPartialKeySize) + 62, + encoding: []byte{ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, err := ExtraPartialKeyLength(testCase.publicKeyLength) + + assert.ErrorIs(t, err, testCase.err) + assert.Equal(t, testCase.encoding, encoding) + }) + } +} + +func Test_NibblesToKey(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + nibbles []byte + key []byte + }{ + "nil nibbles": { + key: []byte{}, + }, + "empty nibbles": { + nibbles: []byte{}, + key: []byte{}, + }, + "0xF 0xF": { + nibbles: []byte{0xF, 0xF}, + key: []byte{0xFF}, + }, + "0x3 0xa 0x0 0x5": { + nibbles: []byte{0x3, 0xa, 
0x0, 0x5}, + key: []byte{0xa3, 0x50}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, + key: []byte{0xaa, 0xff, 0x10}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc 0x2": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, + key: []byte{0xaa, 0xff, 0x10, 0x2c}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, + key: []byte{0xaa, 0xff, 0x10, 0x0c}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + key := NibblesToKey(testCase.nibbles) + + assert.Equal(t, testCase.key, key) + }) + } +} + +func Test_NibblesToKeyLE(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + nibbles []byte + keyLE []byte + }{ + "nil nibbles": { + keyLE: []byte{}, + }, + "empty nibbles": { + nibbles: []byte{}, + keyLE: []byte{}, + }, + "0xF 0xF": { + nibbles: []byte{0xF, 0xF}, + keyLE: []byte{0xFF}, + }, + "0x3 0xa 0x0 0x5": { + nibbles: []byte{0x3, 0xa, 0x0, 0x5}, + keyLE: []byte{0x3a, 0x05}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, + keyLE: []byte{0xaa, 0xff, 0x01}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc 0x2": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, + keyLE: []byte{0xaa, 0xff, 0x01, 0xc2}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, + keyLE: []byte{0xa, 0xaf, 0xf0, 0x1c}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + keyLE := NibblesToKeyLE(testCase.nibbles) + + assert.Equal(t, testCase.keyLE, keyLE) + }) + } +} diff --git a/lib/trie/encodedecode_test/branch_test.go b/lib/trie/encodedecode_test/branch_test.go new file mode 100644 index 0000000000..41b5d68dcc --- /dev/null +++ b/lib/trie/encodedecode_test/branch_test.go @@ -0,0 +1,89 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package encodedecode_test + +import ( + "bytes" + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_Branch_Encode_Decode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branchToEncode *branch.Branch + branchDecoded *branch.Branch + }{ + "empty branch": { + branchToEncode: new(branch.Branch), + branchDecoded: &branch.Branch{ + Key: []byte{}, + Dirty: true, + }, + }, + "branch with key 5": { + branchToEncode: &branch.Branch{ + Key: []byte{5}, + }, + branchDecoded: &branch.Branch{ + Key: []byte{5}, + Dirty: true, + }, + }, + "branch with two bytes key": { + branchToEncode: &branch.Branch{ + Key: []byte{0xf, 0xa}, // note: each byte cannot be larger than 0xf + }, + branchDecoded: &branch.Branch{ + Key: []byte{0xf, 0xa}, + Dirty: true, + }, + }, + "branch with child": { + branchToEncode: &branch.Branch{ + Key: []byte{5}, + Children: [16]node.Node{ + &leaf.Leaf{ + Key: []byte{9}, + Value: []byte{10}, + }, + }, + }, + branchDecoded: &branch.Branch{ + Key: []byte{5}, + Children: [16]node.Node{ + &leaf.Leaf{ + // TODO key and value are nil here?? Why? 
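						// Per the expected result below, branch decoding does not
						// recurse into children: each non-empty child slot is rebuilt
						// as a leaf whose Hash field holds the SCALE-decoded child
						// bytes. The child encoding here is shorter than 32 bytes, so
						// those bytes are the child's inline encoding: 0x41 (leaf
						// header, key length 1), 0x9 (key data) and 0x04, 0x0a (the
						// SCALE-encoded value {10}).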
+ Hash: []byte{0x41, 0x9, 0x4, 0xa}, + }, + }, + Dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + buffer := bytes.NewBuffer(nil) + + err := testCase.branchToEncode.Encode(buffer) + require.NoError(t, err) + + const header = 0 + resultBranch, err := branch.Decode(buffer, header) + require.NoError(t, err) + + assert.Equal(t, testCase.branchDecoded, resultBranch) + }) + } +} diff --git a/lib/trie/encodedecode_test/nibbles_test.go b/lib/trie/encodedecode_test/nibbles_test.go new file mode 100644 index 0000000000..05fd0b5a95 --- /dev/null +++ b/lib/trie/encodedecode_test/nibbles_test.go @@ -0,0 +1,50 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package encodedecode_test + +import ( + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/stretchr/testify/assert" +) + +func Test_NibblesKeyLE(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + nibblesToEncode []byte + nibblesDecoded []byte + }{ + "empty input": { + nibblesToEncode: []byte{}, + nibblesDecoded: []byte{}, + }, + "one byte": { + nibblesToEncode: []byte{1}, + nibblesDecoded: []byte{0, 1}, + }, + "two bytes": { + nibblesToEncode: []byte{1, 2}, + nibblesDecoded: []byte{1, 2}, + }, + "three bytes": { + nibblesToEncode: []byte{1, 2, 3}, + nibblesDecoded: []byte{0, 1, 2, 3}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + keyLE := encode.NibblesToKeyLE(testCase.nibblesToEncode) + nibblesDecoded := decode.KeyLEToNibbles(keyLE) + + assert.Equal(t, testCase.nibblesDecoded, nibblesDecoded) + }) + } +} diff --git a/lib/trie/hash.go b/lib/trie/hash.go deleted file mode 100644 index 921436c5f6..0000000000 --- a/lib/trie/hash.go +++ /dev/null @@ -1,350 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "errors" - "fmt" - "hash" - "io" - "sync" - - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" - "golang.org/x/crypto/blake2b" -) - -var encodingBufferPool = &sync.Pool{ - New: func() interface{} { - const initialBufferCapacity = 1900000 // 1.9MB, from checking capacities at runtime - b := make([]byte, 0, initialBufferCapacity) - return bytes.NewBuffer(b) - }, -} - -var digestBufferPool = &sync.Pool{ - New: func() interface{} { - const bufferCapacity = 32 - b := make([]byte, 0, bufferCapacity) - return bytes.NewBuffer(b) - }, -} - -var hasherPool = &sync.Pool{ - New: func() interface{} { - hasher, err := blake2b.New256(nil) - if err != nil { - // Conversation on why we panic here: - // https://github.com/ChainSafe/gossamer/pull/2009#discussion_r753430764 - panic("cannot create Blake2b-256 hasher: " + err.Error()) - } - return hasher - }, -} - -func hashNode(n Node, digestBuffer io.Writer) (err error) { - encodingBuffer := encodingBufferPool.Get().(*bytes.Buffer) - encodingBuffer.Reset() - defer encodingBufferPool.Put(encodingBuffer) - - const parallel = false - - err = encodeNode(n, encodingBuffer, parallel) - if err != nil { - return fmt.Errorf("cannot encode node: %w", err) - } - - // if length of encoded leaf is less than 32 bytes, do not hash - if encodingBuffer.Len() < 32 { - _, err = digestBuffer.Write(encodingBuffer.Bytes()) - if err != nil { - return fmt.Errorf("cannot write encoded node to buffer: %w", err) - } - 
return nil - } - - // otherwise, hash encoded node - hasher := hasherPool.Get().(hash.Hash) - hasher.Reset() - defer hasherPool.Put(hasher) - - // Note: using the sync.Pool's buffer is useful here. - _, err = hasher.Write(encodingBuffer.Bytes()) - if err != nil { - return fmt.Errorf("cannot hash encoded node: %w", err) - } - - _, err = digestBuffer.Write(hasher.Sum(nil)) - if err != nil { - return fmt.Errorf("cannot write hash sum of node to buffer: %w", err) - } - return nil -} - -var ErrNodeTypeUnsupported = errors.New("node type is not supported") - -type bytesBuffer interface { - // note: cannot compose with io.Writer for mock generation - Write(p []byte) (n int, err error) - Len() int - Bytes() []byte -} - -// encodeNode writes the encoding of the node to the buffer given. -// It is the high-level function wrapping the encoding for different -// node types. The encoding has the following format: -// NodeHeader | Extra partial key length | Partial Key | Value -func encodeNode(n Node, buffer bytesBuffer, parallel bool) (err error) { - switch n := n.(type) { - case *Branch: - err := encodeBranch(n, buffer, parallel) - if err != nil { - return fmt.Errorf("cannot encode branch: %w", err) - } - return nil - case *Leaf: - err := encodeLeaf(n, buffer) - if err != nil { - return fmt.Errorf("cannot encode leaf: %w", err) - } - - n.encodingMu.Lock() - defer n.encodingMu.Unlock() - - // TODO remove this copying since it defeats the purpose of `buffer` - // and the sync.Pool. - n.encoding = make([]byte, buffer.Len()) - copy(n.encoding, buffer.Bytes()) - return nil - case nil: - _, err := buffer.Write([]byte{0}) - if err != nil { - return fmt.Errorf("cannot encode nil node: %w", err) - } - return nil - default: - return fmt.Errorf("%w: %T", ErrNodeTypeUnsupported, n) - } -} - -// encodeBranch encodes a branch with the encoding specified at the top of this package -// to the buffer given. 
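As an illustration of the branch format described in the comment above (and in encode/doc.go earlier in this patch), the following standalone sketch, not taken from the patch itself, assembles the same bytes that Test_encodeBranch in the removed hash_test.go expects for a branch with key {1, 2, 3}, value {100} and leaf children at indices 3 and 7; every byte value is copied from that test.

package main

import "fmt"

func main() {
	header := []byte{0xc3}            // 11 (branch with value) in the top two bits, key length 3
	keyLE := []byte{0x01, 0x23}       // nibbles 1, 2, 3 packed little-endian
	childrenBitmap := []byte{0x88, 0} // bits 3 and 7 set, little-endian uint16
	value := []byte{0x04, 100}        // SCALE: compact length 1, then the value byte
	child3 := []byte{12, 65, 9, 0}    // SCALE-wrapped inline encoding of the leaf at index 3
	child7 := []byte{12, 65, 11, 0}   // SCALE-wrapped inline encoding of the leaf at index 7

	var encoding []byte
	for _, part := range [][]byte{header, keyLE, childrenBitmap, value, child3, child7} {
		encoding = append(encoding, part...)
	}
	fmt.Printf("% x\n", encoding)
}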
-func encodeBranch(b *Branch, buffer io.Writer, parallel bool) (err error) { - if !b.dirty && b.encoding != nil { - _, err = buffer.Write(b.encoding) - if err != nil { - return fmt.Errorf("cannot write stored encoding to buffer: %w", err) - } - return nil - } - - encodedHeader, err := b.header() - if err != nil { - return fmt.Errorf("cannot encode header: %w", err) - } - - _, err = buffer.Write(encodedHeader) - if err != nil { - return fmt.Errorf("cannot write encoded header to buffer: %w", err) - } - - keyLE := nibblesToKeyLE(b.key) - _, err = buffer.Write(keyLE) - if err != nil { - return fmt.Errorf("cannot write encoded key to buffer: %w", err) - } - - childrenBitmap := common.Uint16ToBytes(b.childrenBitmap()) - _, err = buffer.Write(childrenBitmap) - if err != nil { - return fmt.Errorf("cannot write children bitmap to buffer: %w", err) - } - - if b.value != nil { - bytes, err := scale.Marshal(b.value) - if err != nil { - return fmt.Errorf("cannot scale encode value: %w", err) - } - - _, err = buffer.Write(bytes) - if err != nil { - return fmt.Errorf("cannot write encoded value to buffer: %w", err) - } - } - - if parallel { - err = encodeChildrenInParallel(b.children, buffer) - } else { - err = encodeChildrenSequentially(b.children, buffer) - } - if err != nil { - return fmt.Errorf("cannot encode children of branch: %w", err) - } - - return nil -} - -func encodeChildrenInParallel(children [16]Node, buffer io.Writer) (err error) { - type result struct { - index int - buffer *bytes.Buffer - err error - } - - resultsCh := make(chan result) - - for i, child := range children { - go func(index int, child Node) { - buffer := encodingBufferPool.Get().(*bytes.Buffer) - buffer.Reset() - // buffer is put back in the pool after processing its - // data in the select block below. - - err := encodeChild(child, buffer) - - resultsCh <- result{ - index: index, - buffer: buffer, - err: err, - } - }(i, child) - } - - currentIndex := 0 - resultBuffers := make([]*bytes.Buffer, len(children)) - for range children { - result := <-resultsCh - if result.err != nil && err == nil { // only set the first error we get - err = result.err - } - - resultBuffers[result.index] = result.buffer - - // write as many completed buffers to the result buffer. 
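	// Children are encoded concurrently, but the parent encoding must list
	// them in index order. Completed buffers are therefore stashed by index
	// and drained strictly in order: a child's buffer is written once every
	// lower index has been written, then returned to the pool.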
- for currentIndex < len(children) && - resultBuffers[currentIndex] != nil { - bufferSlice := resultBuffers[currentIndex].Bytes() - if len(bufferSlice) > 0 { - // note buffer.Write copies the byte slice given as argument - _, writeErr := buffer.Write(bufferSlice) - if writeErr != nil && err == nil { - err = fmt.Errorf( - "cannot write encoding of child at index %d: %w", - currentIndex, writeErr) - } - } - - encodingBufferPool.Put(resultBuffers[currentIndex]) - resultBuffers[currentIndex] = nil - - currentIndex++ - } - } - - for _, buffer := range resultBuffers { - if buffer == nil { // already emptied and put back in pool - continue - } - encodingBufferPool.Put(buffer) - } - - return err -} - -func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error) { - for i, child := range children { - err = encodeChild(child, buffer) - if err != nil { - return fmt.Errorf("cannot encode child at index %d: %w", i, err) - } - } - return nil -} - -func encodeChild(child Node, buffer io.Writer) (err error) { - var isNil bool - switch impl := child.(type) { - case *Branch: - isNil = impl == nil - case *Leaf: - isNil = impl == nil - default: - isNil = child == nil - } - if isNil { - return nil - } - - scaleEncodedChild, err := encodeAndHash(child) - if err != nil { - return fmt.Errorf("failed to hash and scale encode child: %w", err) - } - - _, err = buffer.Write(scaleEncodedChild) - if err != nil { - return fmt.Errorf("failed to write child to buffer: %w", err) - } - - return nil -} - -func encodeAndHash(n Node) (b []byte, err error) { - buffer := digestBufferPool.Get().(*bytes.Buffer) - buffer.Reset() - defer digestBufferPool.Put(buffer) - - err = hashNode(n, buffer) - if err != nil { - return nil, fmt.Errorf("cannot hash node: %w", err) - } - - scEncChild, err := scale.Marshal(buffer.Bytes()) - if err != nil { - return nil, fmt.Errorf("cannot scale encode hashed node: %w", err) - } - return scEncChild, nil -} - -// encodeLeaf encodes a leaf to the buffer given, with the encoding -// specified at the top of this package. 
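For comparison with the branch sketch above, here is the leaf case, again a standalone illustration rather than code from the patch, using the fixture from Test_encodeLeaf in the removed hash_test.go: key {1, 2, 3} and value {4, 5, 6}.

package main

import "fmt"

func main() {
	header := []byte{0x43}       // 01 (leaf) in the top two bits, key length 3
	keyLE := []byte{0x01, 0x23}  // nibbles 1, 2, 3 packed little-endian
	value := []byte{12, 4, 5, 6} // SCALE: compact length 3, then the value bytes

	encoding := append(append(header, keyLE...), value...)
	fmt.Printf("% x\n", encoding) // 43 01 23 0c 04 05 06
}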
-func encodeLeaf(l *Leaf, buffer io.Writer) (err error) { - l.encodingMu.RLock() - defer l.encodingMu.RUnlock() - if !l.dirty && l.encoding != nil { - _, err = buffer.Write(l.encoding) - if err != nil { - return fmt.Errorf("cannot write stored encoding to buffer: %w", err) - } - return nil - } - - encodedHeader, err := l.header() - if err != nil { - return fmt.Errorf("cannot encode header: %w", err) - } - - _, err = buffer.Write(encodedHeader) - if err != nil { - return fmt.Errorf("cannot write encoded header to buffer: %w", err) - } - - keyLE := nibblesToKeyLE(l.key) - _, err = buffer.Write(keyLE) - if err != nil { - return fmt.Errorf("cannot write LE key to buffer: %w", err) - } - - encodedValue, err := scale.Marshal(l.value) // TODO scale encoder to write to buffer - if err != nil { - return fmt.Errorf("cannot scale marshal value: %w", err) - } - - _, err = buffer.Write(encodedValue) - if err != nil { - return fmt.Errorf("cannot write scale encoded value to buffer: %w", err) - } - - return nil -} diff --git a/lib/trie/hash_test.go b/lib/trie/hash_test.go deleted file mode 100644 index 2e2ea04317..0000000000 --- a/lib/trie/hash_test.go +++ /dev/null @@ -1,1012 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "errors" - "testing" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type writeCall struct { - written []byte - n int - err error -} - -var errTest = errors.New("test error") - -//go:generate mockgen -destination=bytesBuffer_mock_test.go -package $GOPACKAGE -source=hash.go . bytesBuffer -//go:generate mockgen -destination=node_mock_test.go -package $GOPACKAGE -source=node.go . node - -func Test_hashNode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - n Node - writeCall bool - write writeCall - wrappedErr error - errMessage string - }{ - "node encoding error": { - n: NewMockNode(nil), - wrappedErr: ErrNodeTypeUnsupported, - errMessage: "cannot encode node: " + - "node type is not supported: " + - "*trie.MockNode", - }, - "small leaf buffer write error": { - n: &Leaf{ - encoding: []byte{1, 2, 3}, - }, - writeCall: true, - write: writeCall{ - written: []byte{1, 2, 3}, - err: errTest, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded node to buffer: " + - "test error", - }, - "small leaf success": { - n: &Leaf{ - encoding: []byte{1, 2, 3}, - }, - writeCall: true, - write: writeCall{ - written: []byte{1, 2, 3}, - }, - }, - "leaf hash sum buffer write error": { - n: &Leaf{ - encoding: []byte{ - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - }, - writeCall: true, - write: writeCall{ - written: []byte{ - 107, 105, 154, 175, 253, 170, 232, - 135, 240, 21, 207, 148, 82, 117, - 249, 230, 80, 197, 254, 17, 149, - 108, 50, 7, 80, 56, 114, 176, - 84, 114, 125, 234}, - err: errTest, - }, - wrappedErr: errTest, - errMessage: "cannot write hash sum of node to buffer: " + - "test error", - }, - "leaf hash sum success": { - n: &Leaf{ - encoding: []byte{ - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - }, - writeCall: true, - write: writeCall{ - written: []byte{ - 107, 105, 154, 175, 253, 170, 232, - 135, 240, 21, 207, 148, 82, 117, - 249, 230, 80, 197, 254, 17, 149, - 108, 50, 7, 80, 56, 114, 176, - 84, 114, 125, 234}, - }, - }, - } - - for name, 
testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockWriter(ctrl) - if testCase.writeCall { - buffer.EXPECT(). - Write(testCase.write.written). - Return(testCase.write.n, testCase.write.err) - } - - err := hashNode(testCase.n, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeNode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - n Node - writes []writeCall - leafEncodingCopy bool - leafBufferLen int - leafBufferBytes []byte - parallel bool - wrappedErr error - errMessage string - }{ - "branch error": { - n: &Branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}, err: errTest}, - }, - wrappedErr: errTest, - errMessage: "cannot encode branch: " + - "cannot write stored encoding to buffer: " + - "test error", - }, - "branch success": { - n: &Branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}}, - }, - }, - "leaf error": { - n: &Leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}, err: errTest}, - }, - wrappedErr: errTest, - errMessage: "cannot encode leaf: " + - "cannot write stored encoding to buffer: " + - "test error", - }, - "leaf success": { - n: &Leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - {written: []byte{1, 2, 3}}, - }, - leafEncodingCopy: true, - leafBufferLen: 3, - leafBufferBytes: []byte{1, 2, 3}, - }, - "nil node error": { - writes: []writeCall{ - {written: []byte{0}, err: errTest}, - }, - wrappedErr: errTest, - errMessage: "cannot encode nil node: test error", - }, - "nil node success": { - writes: []writeCall{ - {written: []byte{0}}, - }, - }, - "unsupported node type": { - n: NewMockNode(nil), - wrappedErr: ErrNodeTypeUnsupported, - errMessage: "node type is not supported: *trie.MockNode", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockbytesBuffer(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). - Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - if testCase.leafEncodingCopy { - previousCall = buffer.EXPECT().Len(). - Return(testCase.leafBufferLen). - After(previousCall) - buffer.EXPECT().Bytes(). - Return(testCase.leafBufferBytes). 
- After(previousCall) - } - - err := encodeNode(testCase.n, buffer, testCase.parallel) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeBranch(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - branch *Branch - writes []writeCall - parallel bool - wrappedErr error - errMessage string - }{ - "clean branch with encoding": { - branch: &Branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { // stored encoding - written: []byte{1, 2, 3}, - }, - }, - }, - "write error for clean branch with encoding": { - branch: &Branch{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { // stored encoding - written: []byte{1, 2, 3}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write stored encoding to buffer: test error", - }, - "header encoding error": { - branch: &Branch{ - key: make([]byte, 63+(1<<16)), - }, - wrappedErr: ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length greater than or equal to 2^16", - }, - "buffer write error for encoded header": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded header to buffer: test error", - }, - "buffer write error for encoded key": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded key to buffer: test error", - }, - "buffer write error for children bitmap": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]Node{ - nil, nil, nil, &Leaf{key: []byte{9}}, - nil, nil, nil, &Leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write children bitmap to buffer: test error", - }, - "buffer write error for value": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]Node{ - nil, nil, nil, &Leaf{key: []byte{9}}, - nil, nil, nil, &Leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded value to buffer: test error", - }, - "buffer write error for children encoded sequentially": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]Node{ - nil, nil, nil, &Leaf{key: []byte{9}}, - nil, nil, nil, &Leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // children - written: []byte{12, 65, 9, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot encode children of branch: " + - "cannot encode child at index 3: " + - "failed to write child to buffer: test error", - }, - "buffer write error for 
children encoded in parallel": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]Node{ - nil, nil, nil, &Leaf{key: []byte{9}}, - nil, nil, nil, &Leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - err: errTest, - }, - { // second children - written: []byte{12, 65, 11, 0}, - }, - }, - parallel: true, - wrappedErr: errTest, - errMessage: "cannot encode children of branch: " + - "cannot write encoding of child at index 3: " + - "test error", - }, - "success with parallel children encoding": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]Node{ - nil, nil, nil, &Leaf{key: []byte{9}}, - nil, nil, nil, &Leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - }, - { // second children - written: []byte{12, 65, 11, 0}, - }, - }, - parallel: true, - }, - "success with sequential children encoding": { - branch: &Branch{ - key: []byte{1, 2, 3}, - value: []byte{100}, - children: [16]Node{ - nil, nil, nil, &Leaf{key: []byte{9}}, - nil, nil, nil, &Leaf{key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - }, - { // second children - written: []byte{12, 65, 11, 0}, - }, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). 
- Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeBranch(testCase.branch, buffer, testCase.parallel) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -//go:generate mockgen -destination=readwriter_mock_test.go -package $GOPACKAGE io ReadWriter - -func Test_encodeChildrenInParallel(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - children [16]Node - writes []writeCall - wrappedErr error - errMessage string - }{ - "no children": {}, - "first child not nil": { - children: [16]Node{ - &Leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "last child not nil": { - children: [16]Node{ - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - &Leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "first two children not nil": { - children: [16]Node{ - &Leaf{key: []byte{1}}, - &Leaf{key: []byte{2}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - { - written: []byte{12, 65, 2, 0}, - }, - }, - }, - "encoding error": { - children: [16]Node{ - nil, nil, nil, nil, - nil, nil, nil, nil, - nil, nil, nil, - &Leaf{ - key: []byte{1}, - }, - nil, nil, nil, nil, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoding of child at index 11: " + - "test error", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). 
- Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeChildrenInParallel(testCase.children, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeChildrenSequentially(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - children [16]Node - writes []writeCall - wrappedErr error - errMessage string - }{ - "no children": {}, - "first child not nil": { - children: [16]Node{ - &Leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "last child not nil": { - children: [16]Node{ - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, - &Leaf{key: []byte{1}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - }, - }, - "first two children not nil": { - children: [16]Node{ - &Leaf{key: []byte{1}}, - &Leaf{key: []byte{2}}, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - }, - { - written: []byte{12, 65, 2, 0}, - }, - }, - }, - "encoding error": { - children: [16]Node{ - nil, nil, nil, nil, - nil, nil, nil, nil, - nil, nil, nil, - &Leaf{ - key: []byte{1}, - }, - nil, nil, nil, nil, - }, - writes: []writeCall{ - { - written: []byte{12, 65, 1, 0}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot encode child at index 11: " + - "failed to write child to buffer: test error", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). 
- Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeChildrenSequentially(testCase.children, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer - -func Test_encodeChild(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - child Node - writeCall bool - write writeCall - wrappedErr error - errMessage string - }{ - "nil node": {}, - "nil leaf": { - child: (*Leaf)(nil), - }, - "nil branch": { - child: (*Branch)(nil), - }, - "empty leaf child": { - child: &Leaf{}, - writeCall: true, - write: writeCall{ - written: []byte{8, 64, 0}, - }, - }, - "empty branch child": { - child: &Branch{}, - writeCall: true, - write: writeCall{ - written: []byte{12, 128, 0, 0}, - }, - }, - "buffer write error": { - child: &Branch{}, - writeCall: true, - write: writeCall{ - written: []byte{12, 128, 0, 0}, - err: errTest, - }, - wrappedErr: errTest, - errMessage: "failed to write child to buffer: test error", - }, - "leaf child": { - child: &Leaf{ - key: []byte{1}, - value: []byte{2}, - }, - writeCall: true, - write: writeCall{ - written: []byte{16, 65, 1, 4, 2}, - }, - }, - "branch child": { - child: &Branch{ - key: []byte{1}, - value: []byte{2}, - children: [16]Node{ - nil, nil, &Leaf{ - key: []byte{5}, - value: []byte{6}, - }, - }, - }, - writeCall: true, - write: writeCall{ - written: []byte{44, 193, 1, 4, 0, 4, 2, 16, 65, 5, 4, 6}, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockWriter(ctrl) - - if testCase.writeCall { - buffer.EXPECT(). - Write(testCase.write.written). 
- Return(testCase.write.n, testCase.write.err) - } - - err := encodeChild(testCase.child, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} - -func Test_encodeAndHash(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - n Node - b []byte - wrappedErr error - errMessage string - }{ - "node encoding error": { - n: NewMockNode(nil), - wrappedErr: ErrNodeTypeUnsupported, - errMessage: "cannot hash node: " + - "cannot encode node: " + - "node type is not supported: " + - "*trie.MockNode", - }, - "leaf": { - n: &Leaf{}, - b: []byte{0x8, 0x40, 0}, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - b, err := encodeAndHash(testCase.n) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - - assert.Equal(t, testCase.b, b) - }) - } -} - -func Test_encodeLeaf(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - leaf *Leaf - writes []writeCall - wrappedErr error - errMessage string - }{ - "clean leaf with encoding": { - leaf: &Leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{1, 2, 3}, - }, - }, - }, - "write error for clean leaf with encoding": { - leaf: &Leaf{ - encoding: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{1, 2, 3}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write stored encoding to buffer: test error", - }, - "header encoding error": { - leaf: &Leaf{ - key: make([]byte, 63+(1<<16)), - }, - wrappedErr: ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length greater than or equal to 2^16", - }, - "buffer write error for encoded header": { - leaf: &Leaf{ - key: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{67}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write encoded header to buffer: test error", - }, - "buffer write error for encoded key": { - leaf: &Leaf{ - key: []byte{1, 2, 3}, - }, - writes: []writeCall{ - { - written: []byte{67}, - }, - { - written: []byte{1, 35}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write LE key to buffer: test error", - }, - "buffer write error for encoded value": { - leaf: &Leaf{ - key: []byte{1, 2, 3}, - value: []byte{4, 5, 6}, - }, - writes: []writeCall{ - { - written: []byte{67}, - }, - { - written: []byte{1, 35}, - }, - { - written: []byte{12, 4, 5, 6}, - err: errTest, - }, - }, - wrappedErr: errTest, - errMessage: "cannot write scale encoded value to buffer: test error", - }, - "success": { - leaf: &Leaf{ - key: []byte{1, 2, 3}, - value: []byte{4, 5, 6}, - }, - writes: []writeCall{ - { - written: []byte{67}, - }, - { - written: []byte{1, 35}, - }, - { - written: []byte{12, 4, 5, 6}, - }, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - buffer := NewMockReadWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := buffer.EXPECT(). - Write(write.written). 
- Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := encodeLeaf(testCase.leaf, buffer) - - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - require.NoError(t, err) - } - }) - } -} diff --git a/lib/trie/leaf/buffer_mock_test.go b/lib/trie/leaf/buffer_mock_test.go new file mode 100644 index 0000000000..d3404d0a48 --- /dev/null +++ b/lib/trie/leaf/buffer_mock_test.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/lib/trie/encode (interfaces: Buffer) + +// Package leaf is a generated GoMock package. +package leaf + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockBuffer is a mock of Buffer interface. +type MockBuffer struct { + ctrl *gomock.Controller + recorder *MockBufferMockRecorder +} + +// MockBufferMockRecorder is the mock recorder for MockBuffer. +type MockBufferMockRecorder struct { + mock *MockBuffer +} + +// NewMockBuffer creates a new mock instance. +func NewMockBuffer(ctrl *gomock.Controller) *MockBuffer { + mock := &MockBuffer{ctrl: ctrl} + mock.recorder = &MockBufferMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBuffer) EXPECT() *MockBufferMockRecorder { + return m.recorder +} + +// Bytes mocks base method. +func (m *MockBuffer) Bytes() []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Bytes") + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Bytes indicates an expected call of Bytes. +func (mr *MockBufferMockRecorder) Bytes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockBuffer)(nil).Bytes)) +} + +// Len mocks base method. +func (m *MockBuffer) Len() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Len") + ret0, _ := ret[0].(int) + return ret0 +} + +// Len indicates an expected call of Len. +func (mr *MockBufferMockRecorder) Len() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockBuffer)(nil).Len)) +} + +// Write mocks base method. +func (m *MockBuffer) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockBufferMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockBuffer)(nil).Write), arg0) +} diff --git a/lib/trie/leaf/copy.go b/lib/trie/leaf/copy.go new file mode 100644 index 0000000000..3f07972249 --- /dev/null +++ b/lib/trie/leaf/copy.go @@ -0,0 +1,29 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import "github.com/ChainSafe/gossamer/lib/trie/node" + +// Copy deep copies the leaf. 
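// It returns a node.Node, so callers that need the concrete type have to
// assert it, for example cpy := original.Copy().(*Leaf).
// Key, Value, Hash and Encoding are freshly allocated in the copy, so
// mutating the copy does not alias the original leaf's slices.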
+func (l *Leaf) Copy() node.Node { + l.RLock() + defer l.RUnlock() + + l.encodingMu.RLock() + defer l.encodingMu.RUnlock() + + cpy := &Leaf{ + Key: make([]byte, len(l.Key)), + Value: make([]byte, len(l.Value)), + Dirty: l.Dirty, + Hash: make([]byte, len(l.Hash)), + Encoding: make([]byte, len(l.Encoding)), + Generation: l.Generation, + } + copy(cpy.Key, l.Key) + copy(cpy.Value, l.Value) + copy(cpy.Hash, l.Hash) + copy(cpy.Encoding, l.Encoding) + return cpy +} diff --git a/lib/trie/leaf/decode.go b/lib/trie/leaf/decode.go new file mode 100644 index 0000000000..a01afb4a1c --- /dev/null +++ b/lib/trie/leaf/decode.go @@ -0,0 +1,57 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import ( + "errors" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +var ( + ErrReadHeaderByte = errors.New("cannot read header byte") + ErrNodeTypeIsNotALeaf = errors.New("node type is not a leaf") + ErrDecodeValue = errors.New("cannot decode value") +) + +// Decode reads and decodes from a reader with the encoding specified in lib/trie/encode/doc.go. +func Decode(r io.Reader, header byte) (leaf *Leaf, err error) { // TODO return leaf + if header == 0 { // TODO remove this is taken care of by the caller + header, err = decode.ReadNextByte(r) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) + } + } + + nodeType := header >> 6 + if nodeType != 1 { + return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) + } + + leaf = new(Leaf) + + keyLen := header & 0x3f + leaf.Key, err = decode.Key(r, keyLen) + if err != nil { + return nil, fmt.Errorf("cannot decode key: %w", err) + } + + sd := scale.NewDecoder(r) + var value []byte + err = sd.Decode(&value) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) + } + + if len(value) > 0 { + leaf.Value = value + } + + leaf.Dirty = true // TODO move this as soon as it gets modified + + return leaf, nil +} diff --git a/lib/trie/leaf/decode_test.go b/lib/trie/leaf/decode_test.go new file mode 100644 index 0000000000..5e98eb71d5 --- /dev/null +++ b/lib/trie/leaf/decode_test.go @@ -0,0 +1,121 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import ( + "bytes" + "io" + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { + encoded, err := scale.Marshal(b) + require.NoError(t, err) + return encoded +} + +func concatByteSlices(slices [][]byte) (concatenated []byte) { + length := 0 + for i := range slices { + length += len(slices[i]) + } + concatenated = make([]byte, 0, length) + for _, slice := range slices { + concatenated = append(concatenated, slice...) 
+ } + return concatenated +} + +func Test_Decode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + header byte + leaf *Leaf + errWrapped error + errMessage string + }{ + "no data with header 0": { + reader: bytes.NewBuffer(nil), + errWrapped: ErrReadHeaderByte, + errMessage: "cannot read header byte: EOF", + }, + "no data with header 1": { + reader: bytes.NewBuffer(nil), + header: 1, + errWrapped: ErrNodeTypeIsNotALeaf, + errMessage: "node type is not a leaf: 0", + }, + "first byte as 0 header 0": { + reader: bytes.NewBuffer([]byte{0}), + errWrapped: ErrNodeTypeIsNotALeaf, + errMessage: "node type is not a leaf: 0", + }, + "key decoding error": { + reader: bytes.NewBuffer([]byte{ + 65, // node type 1 and key length 1 + // missing key data byte + }), + errWrapped: decode.ErrReadKeyData, + errMessage: "cannot decode key: cannot read key data: EOF", + }, + "value decoding error": { + reader: bytes.NewBuffer([]byte{ + 65, // node type 1 and key length 1 + 9, // key data + // missing value data + }), + errWrapped: ErrDecodeValue, + errMessage: "cannot decode value: EOF", + }, + "zero value": { + reader: bytes.NewBuffer([]byte{ + 65, // node type 1 and key length 1 + 9, // key data + 0, // missing value data + }), + leaf: &Leaf{ + Key: []byte{9}, + Dirty: true, + }, + }, + "success": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + { + 65, // node type 1 and key length 1 + 9, // key data + }, + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // value data + }), + ), + leaf: &Leaf{ + Key: []byte{9}, + Value: []byte{1, 2, 3, 4, 5}, + Dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + leaf, err := Decode(testCase.reader, testCase.header) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.leaf, leaf) + }) + } +} diff --git a/lib/trie/leaf/dirty.go b/lib/trie/leaf/dirty.go new file mode 100644 index 0000000000..b955754b03 --- /dev/null +++ b/lib/trie/leaf/dirty.go @@ -0,0 +1,14 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +// IsDirty returns the dirty status of the leaf. +func (l *Leaf) IsDirty() bool { + return l.Dirty +} + +// SetDirty sets the dirty status to the leaf. +func (l *Leaf) SetDirty(dirty bool) { + l.Dirty = dirty +} diff --git a/lib/trie/leaf/encode.go b/lib/trie/leaf/encode.go new file mode 100644 index 0000000000..8467d204f7 --- /dev/null +++ b/lib/trie/leaf/encode.go @@ -0,0 +1,203 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import ( + "bytes" + "fmt" + "hash" + "io" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/ChainSafe/gossamer/lib/trie/pools" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +// SetEncodingAndHash sets the encoding and hash slices +// given to the branch. Note it does not copy them, so beware. +func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { + l.encodingMu.Lock() + l.Encoding = enc + l.encodingMu.Unlock() + l.Hash = hash +} + +// GetHash returns the hash of the leaf. +// Note it does not copy it, so modifying +// the returned hash will modify the hash +// of the branch. +func (l *Leaf) GetHash() []byte { + return l.Hash +} + +// EncodeAndHash returns the encoding of the leaf and +// the blake2b hash digest of the encoding of the leaf. 
+// If the encoding is less than 32 bytes, the hash returned +// is the encoding and not the hash of the encoding. +func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { + l.encodingMu.RLock() + if !l.IsDirty() && l.Encoding != nil && l.Hash != nil { + l.encodingMu.RUnlock() + return l.Encoding, l.Hash, nil + } + l.encodingMu.RUnlock() + + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.EncodingBuffers.Put(buffer) + + err = l.Encode(buffer) + if err != nil { + return nil, nil, err + } + + bufferBytes := buffer.Bytes() + + l.encodingMu.Lock() + // TODO remove this copying since it defeats the purpose of `buffer` + // and the sync.Pool. + l.Encoding = make([]byte, len(bufferBytes)) + copy(l.Encoding, bufferBytes) + l.encodingMu.Unlock() + encoding = l.Encoding // no need to copy + + if len(bufferBytes) < 32 { + l.Hash = make([]byte, len(bufferBytes)) + copy(l.Hash, bufferBytes) + hash = l.Hash // no need to copy + return encoding, hash, nil + } + + // Note: using the sync.Pool's buffer is useful here. + hashArray, err := common.Blake2bHash(buffer.Bytes()) + if err != nil { + return nil, nil, err + } + + l.Hash = hashArray[:] + hash = l.Hash // no need to copy + + return encoding, hash, nil +} + +// Encode encodes a leaf to the buffer given. +// The encoding has the following format: +// NodeHeader | Extra partial key length | Partial Key | Value +func (l *Leaf) Encode(buffer encode.Buffer) (err error) { + // if l == nil { + // // TODO remove if not needed + // _, err := buffer.Write([]byte{0}) + // if err != nil { + // return fmt.Errorf("cannot write nil encoding to buffer: %w", err) + // } + // return nil + // } + + l.encodingMu.RLock() + if !l.Dirty && l.Encoding != nil { + _, err = buffer.Write(l.Encoding) + l.encodingMu.RUnlock() + if err != nil { + return fmt.Errorf("cannot write stored encoding to buffer: %w", err) + } + return nil + } + l.encodingMu.RUnlock() + + encodedHeader, err := l.Header() + if err != nil { + return fmt.Errorf("cannot encode header: %w", err) + } + + _, err = buffer.Write(encodedHeader) + if err != nil { + return fmt.Errorf("cannot write encoded header to buffer: %w", err) + } + + keyLE := encode.NibblesToKeyLE(l.Key) + _, err = buffer.Write(keyLE) + if err != nil { + return fmt.Errorf("cannot write LE key to buffer: %w", err) + } + + encodedValue, err := scale.Marshal(l.Value) // TODO scale encoder to write to buffer + if err != nil { + return fmt.Errorf("cannot scale marshal value: %w", err) + } + + _, err = buffer.Write(encodedValue) + if err != nil { + return fmt.Errorf("cannot write scale encoded value to buffer: %w", err) + } + + // TODO remove this copying since it defeats the purpose of `buffer` + // and the sync.Pool. + l.encodingMu.Lock() + defer l.encodingMu.Unlock() + l.Encoding = make([]byte, buffer.Len()) + copy(l.Encoding, buffer.Bytes()) + return nil +} + +// ScaleEncodeHash hashes the node (blake2b sum on encoded value) +// and then SCALE encodes it. This is used to encode children +// nodes of branches. 
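// For example (see Test_Leaf_ScaleEncodeHash), an empty leaf encodes to
// {0x40, 0x00}: header 0x40 for an empty key plus the SCALE-encoded nil value
// 0x00. Being shorter than 32 bytes, that encoding is used as-is instead of
// its blake2b digest, and SCALE encoding it as a byte slice yields
// {0x08, 0x40, 0x00}, where 0x08 is the compact length prefix for two bytes.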
+func (l *Leaf) ScaleEncodeHash() (b []byte, err error) { + // if l == nil { // TODO remove + // panic("Should write 0 to buffer") + // } + + buffer := pools.DigestBuffers.Get().(*bytes.Buffer) + buffer.Reset() + defer pools.DigestBuffers.Put(buffer) + + err = l.hash(buffer) + if err != nil { + return nil, fmt.Errorf("cannot hash node: %w", err) + } + + scEncChild, err := scale.Marshal(buffer.Bytes()) + if err != nil { + return nil, fmt.Errorf("cannot scale encode hashed node: %w", err) + } + return scEncChild, nil +} + +func (l *Leaf) hash(writer io.Writer) (err error) { + encodingBuffer := pools.EncodingBuffers.Get().(*bytes.Buffer) + encodingBuffer.Reset() + defer pools.EncodingBuffers.Put(encodingBuffer) + + err = l.Encode(encodingBuffer) + if err != nil { + return fmt.Errorf("cannot encode leaf: %w", err) + } + + // if length of encoded leaf is less than 32 bytes, do not hash + if encodingBuffer.Len() < 32 { + _, err = writer.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot write encoded leaf to buffer: %w", err) + } + return nil + } + + // otherwise, hash encoded node + hasher := pools.Hashers.Get().(hash.Hash) + hasher.Reset() + defer pools.Hashers.Put(hasher) + + // Note: using the sync.Pool's buffer is useful here. + _, err = hasher.Write(encodingBuffer.Bytes()) + if err != nil { + return fmt.Errorf("cannot hash encoded node: %w", err) + } + + _, err = writer.Write(hasher.Sum(nil)) + if err != nil { + return fmt.Errorf("cannot write hash sum of leaf to buffer: %w", err) + } + return nil +} diff --git a/lib/trie/leaf/encode_test.go b/lib/trie/leaf/encode_test.go new file mode 100644 index 0000000000..513b61eb90 --- /dev/null +++ b/lib/trie/leaf/encode_test.go @@ -0,0 +1,318 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import ( + "errors" + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type writeCall struct { + written []byte + n int + err error +} + +var errTest = errors.New("test error") + +//go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/encode Buffer + +func Test_Leaf_Encode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + writes []writeCall + bufferLenCall bool + bufferBytesCall bool + bufferBytes []byte + expectedEncoding []byte + wrappedErr error + errMessage string + }{ + "clean leaf with encoding": { + leaf: &Leaf{ + Encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{1, 2, 3}, + }, + }, + expectedEncoding: []byte{1, 2, 3}, + }, + "write error for clean leaf with encoding": { + leaf: &Leaf{ + Encoding: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{1, 2, 3}, + err: errTest, + }, + }, + expectedEncoding: []byte{1, 2, 3}, + wrappedErr: errTest, + errMessage: "cannot write stored encoding to buffer: test error", + }, + "header encoding error": { + leaf: &Leaf{ + Key: make([]byte, 63+(1<<16)), + }, + wrappedErr: encode.ErrPartialKeyTooBig, + errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", + }, + "buffer write error for encoded header": { + leaf: &Leaf{ + Key: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{67}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded header to buffer: test error", + }, + "buffer write 
error for encoded key": { + leaf: &Leaf{ + Key: []byte{1, 2, 3}, + }, + writes: []writeCall{ + { + written: []byte{67}, + }, + { + written: []byte{1, 35}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write LE key to buffer: test error", + }, + "buffer write error for encoded value": { + leaf: &Leaf{ + Key: []byte{1, 2, 3}, + Value: []byte{4, 5, 6}, + }, + writes: []writeCall{ + { + written: []byte{67}, + }, + { + written: []byte{1, 35}, + }, + { + written: []byte{12, 4, 5, 6}, + err: errTest, + }, + }, + wrappedErr: errTest, + errMessage: "cannot write scale encoded value to buffer: test error", + }, + "success": { + leaf: &Leaf{ + Key: []byte{1, 2, 3}, + Value: []byte{4, 5, 6}, + }, + writes: []writeCall{ + { + written: []byte{67}, + }, + { + written: []byte{1, 35}, + }, + { + written: []byte{12, 4, 5, 6}, + }, + }, + bufferLenCall: true, + bufferBytesCall: true, + bufferBytes: []byte{1, 2, 3}, + expectedEncoding: []byte{1, 2, 3}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + buffer := NewMockBuffer(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := buffer.EXPECT(). + Write(write.written). + Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + if testCase.bufferLenCall { + buffer.EXPECT().Len().Return(len(testCase.bufferBytes)) + } + if testCase.bufferBytesCall { + buffer.EXPECT().Bytes().Return(testCase.bufferBytes) + } + + err := testCase.leaf.Encode(buffer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + assert.Equal(t, testCase.expectedEncoding, testCase.leaf.Encoding) + }) + } +} + +func Test_Leaf_ScaleEncodeHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + b []byte + wrappedErr error + errMessage string + }{ + "leaf": { + leaf: &Leaf{}, + b: []byte{0x8, 0x40, 0}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + b, err := testCase.leaf.ScaleEncodeHash() + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + + assert.Equal(t, testCase.b, b) + }) + } +} + +//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer + +func Test_Leaf_hash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + writeCall bool + write writeCall + wrappedErr error + errMessage string + }{ + "small leaf buffer write error": { + leaf: &Leaf{ + Encoding: []byte{1, 2, 3}, + }, + writeCall: true, + write: writeCall{ + written: []byte{1, 2, 3}, + err: errTest, + }, + wrappedErr: errTest, + errMessage: "cannot write encoded leaf to buffer: " + + "test error", + }, + "small leaf success": { + leaf: &Leaf{ + Encoding: []byte{1, 2, 3}, + }, + writeCall: true, + write: writeCall{ + written: []byte{1, 2, 3}, + }, + }, + "leaf hash sum buffer write error": { + leaf: &Leaf{ + Encoding: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + }, + }, + writeCall: true, + write: writeCall{ + written: []byte{ + 107, 105, 154, 175, 253, 170, 232, + 135, 240, 21, 207, 148, 82, 117, + 249, 230, 
80, 197, 254, 17, 149, + 108, 50, 7, 80, 56, 114, 176, + 84, 114, 125, 234}, + err: errTest, + }, + wrappedErr: errTest, + errMessage: "cannot write hash sum of leaf to buffer: " + + "test error", + }, + "leaf hash sum success": { + leaf: &Leaf{ + Encoding: []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + }, + }, + writeCall: true, + write: writeCall{ + written: []byte{ + 107, 105, 154, 175, 253, 170, 232, + 135, 240, 21, 207, 148, 82, 117, + 249, 230, 80, 197, 254, 17, 149, + 108, 50, 7, 80, 56, 114, 176, + 84, 114, 125, 234}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := NewMockWriter(ctrl) + if testCase.writeCall { + writer.EXPECT(). + Write(testCase.write.written). + Return(testCase.write.n, testCase.write.err) + } + + err := testCase.leaf.hash(writer) + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/lib/trie/leaf/generation.go b/lib/trie/leaf/generation.go new file mode 100644 index 0000000000..1ce46bf81d --- /dev/null +++ b/lib/trie/leaf/generation.go @@ -0,0 +1,14 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +// SetGeneration sets the generation given to the leaf. +func (l *Leaf) SetGeneration(generation uint64) { + l.Generation = generation +} + +// GetGeneration returns the generation of the leaf. +func (l *Leaf) GetGeneration() uint64 { + return l.Generation +} diff --git a/lib/trie/leaf/header.go b/lib/trie/leaf/header.go new file mode 100644 index 0000000000..264d38ab05 --- /dev/null +++ b/lib/trie/leaf/header.go @@ -0,0 +1,26 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import "github.com/ChainSafe/gossamer/lib/trie/encode" + +// Header creates the encoded header for the leaf. +func (l *Leaf) Header() (encoding []byte, err error) { + var header byte = 1 << 6 + var encodedPublicKeyLength []byte + + if len(l.Key) >= 63 { + header = header | 0x3f + encodedPublicKeyLength, err = encode.ExtraPartialKeyLength(len(l.Key)) + if err != nil { + return nil, err + } + } else { + header = header | byte(len(l.Key)) + } + + encoding = append([]byte{header}, encodedPublicKeyLength...) 
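+	// Header layout (per the trie node encoding spec removed from node.go
+	// below): the two most significant bits are 01 for a leaf, the six
+	// least significant bits hold the partial key length, and 0x3f marks
+	// an overflow followed by the extra partial key length bytes built above.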
+ + return encoding, nil +} diff --git a/lib/trie/leaf/header_test.go b/lib/trie/leaf/header_test.go new file mode 100644 index 0000000000..a1825034c1 --- /dev/null +++ b/lib/trie/leaf/header_test.go @@ -0,0 +1,74 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import ( + "testing" + + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/stretchr/testify/assert" +) + +func Test_Leaf_Header(t *testing.T) { + testCases := map[string]struct { + leaf *Leaf + encoding []byte + wrappedErr error + errMessage string + }{ + "no key": { + leaf: &Leaf{}, + encoding: []byte{0x40}, + }, + "key of length 30": { + leaf: &Leaf{ + Key: make([]byte, 30), + }, + encoding: []byte{0x5e}, + }, + "key of length 62": { + leaf: &Leaf{ + Key: make([]byte, 62), + }, + encoding: []byte{0x7e}, + }, + "key of length 63": { + leaf: &Leaf{ + Key: make([]byte, 63), + }, + encoding: []byte{0x7f, 0x0}, + }, + "key of length 64": { + leaf: &Leaf{ + Key: make([]byte, 64), + }, + encoding: []byte{0x7f, 0x1}, + }, + "key too big": { + leaf: &Leaf{ + Key: make([]byte, 65535+63), + }, + wrappedErr: encode.ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, err := testCase.leaf.Header() + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.encoding, encoding) + }) + } +} diff --git a/lib/trie/leaf/key.go b/lib/trie/leaf/key.go new file mode 100644 index 0000000000..9a7d3a11d6 --- /dev/null +++ b/lib/trie/leaf/key.go @@ -0,0 +1,11 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +// SetKey sets the key to the leaf. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the leaf. +func (l *Leaf) SetKey(key []byte) { + l.Key = key +} diff --git a/lib/trie/leaf/leaf.go b/lib/trie/leaf/leaf.go new file mode 100644 index 0000000000..7e86ad9c95 --- /dev/null +++ b/lib/trie/leaf/leaf.go @@ -0,0 +1,33 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package leaf + +import ( + "fmt" + "sync" + + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/node" +) + +var _ node.Node = (*Leaf)(nil) + +// Leaf is a leaf in the trie. +type Leaf struct { + Key []byte // partial key + Value []byte + Dirty bool + Hash []byte + Encoding []byte + encodingMu sync.RWMutex + Generation uint64 + sync.RWMutex +} + +func (l *Leaf) String() string { + if len(l.Value) > 1024 { + return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.Key, common.MustBlake2bHash(l.Value), l.Dirty) + } + return fmt.Sprintf("leaf key=%x value=%v dirty=%v", l.Key, l.Value, l.Dirty) +} diff --git a/lib/trie/leaf/writer_mock_test.go b/lib/trie/leaf/writer_mock_test.go new file mode 100644 index 0000000000..04cb474a72 --- /dev/null +++ b/lib/trie/leaf/writer_mock_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: io (interfaces: Writer) + +// Package leaf is a generated GoMock package. +package leaf + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockWriter is a mock of Writer interface. 
+type MockWriter struct { + ctrl *gomock.Controller + recorder *MockWriterMockRecorder +} + +// MockWriterMockRecorder is the mock recorder for MockWriter. +type MockWriterMockRecorder struct { + mock *MockWriter +} + +// NewMockWriter creates a new mock instance. +func NewMockWriter(ctrl *gomock.Controller) *MockWriter { + mock := &MockWriter{ctrl: ctrl} + mock.recorder = &MockWriterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWriter) EXPECT() *MockWriterMockRecorder { + return m.recorder +} + +// Write mocks base method. +func (m *MockWriter) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), arg0) +} diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 0e38e19ea8..4df79c6c33 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -5,6 +5,9 @@ package trie import ( "bytes" + + "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/node" ) // findAndRecord search for a desired key recording all the nodes in the path including the desired node @@ -12,7 +15,7 @@ func findAndRecord(t *Trie, key []byte, recorder *recorder) error { return find(t.root, key, recorder) } -func find(parent Node, key []byte, recorder *recorder) error { +func find(parent node.Node, key []byte, recorder *recorder) error { enc, hash, err := parent.EncodeAndHash() if err != nil { return err @@ -20,22 +23,22 @@ func find(parent Node, key []byte, recorder *recorder) error { recorder.record(hash, enc) - b, ok := parent.(*Branch) + b, ok := parent.(*branch.Branch) if !ok { return nil } - length := lenCommonPrefix(b.key, key) + length := lenCommonPrefix(b.Key, key) // found the value at this node - if bytes.Equal(b.key, key) || len(key) == 0 { + if bytes.Equal(b.Key, key) || len(key) == 0 { return nil } // did not find value - if bytes.Equal(b.key[:length], key) && len(key) < len(b.key) { + if bytes.Equal(b.Key[:length], key) && len(key) < len(b.Key) { return nil } - return find(b.children[key[length]], key[length+1:], recorder) + return find(b.Children[key[length]], key[length+1:], recorder) } diff --git a/lib/trie/node.go b/lib/trie/node.go deleted file mode 100644 index 7c9f316a31..0000000000 --- a/lib/trie/node.go +++ /dev/null @@ -1,577 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -//nolint:lll -// Modified Merkle-Patricia Trie -// See https://github.com/w3f/polkadot-spec/blob/master/runtime-environment-spec/polkadot_re_spec.pdf for the full specification. 
-// -// Note that for the following definitions, `|` denotes concatenation -// -// Branch encoding: -// NodeHeader | Extra partial key length | Partial Key | Value -// `NodeHeader` is a byte such that: -// most significant two bits of `NodeHeader`: 10 if branch w/o value, 11 if branch w/ value -// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) -// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length -// `Partial Key` is the branch's key -// `Value` is: Children Bitmap | SCALE Branch node Value | Hash(Enc(Child[i_1])) | Hash(Enc(Child[i_2])) | ... | Hash(Enc(Child[i_n])) -// -// Leaf encoding: -// NodeHeader | Extra partial key length | Partial Key | Value -// `NodeHeader` is a byte such that: -// most significant two bits of `NodeHeader`: 01 -// least significant six bits of `NodeHeader`: if len(key) > 62, 0x3f, otherwise len(key) -// `Extra partial key length` is included if len(key) > 63 and consists of the remaining key length -// `Partial Key` is the leaf's key -// `Value` is the leaf's SCALE encoded value - -package trie - -import ( - "bytes" - "errors" - "fmt" - "io" - "sync" - - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" -) - -// Node is the interface for trie methods -type Node interface { - EncodeAndHash() ([]byte, []byte, error) - Decode(r io.Reader, h byte) error - IsDirty() bool - SetDirty(dirty bool) - SetKey(key []byte) - String() string - SetEncodingAndHash([]byte, []byte) - GetHash() []byte - GetGeneration() uint64 - SetGeneration(uint64) - Copy() Node -} - -type ( - // Branch is a branch in the trie. - Branch struct { - key []byte // partial key - children [16]Node - value []byte - dirty bool - hash []byte - encoding []byte - generation uint64 - sync.RWMutex - } - - // Leaf is a leaf in the trie. - Leaf struct { - key []byte // partial key - value []byte - dirty bool - hash []byte - encoding []byte - encodingMu sync.RWMutex - generation uint64 - sync.RWMutex - } -) - -// SetGeneration sets the generation given to the branch. -func (b *Branch) SetGeneration(generation uint64) { - b.generation = generation -} - -// SetGeneration sets the generation given to the leaf. -func (l *Leaf) SetGeneration(generation uint64) { - l.generation = generation -} - -// Copy deep copies the branch. -func (b *Branch) Copy() Node { - b.RLock() - defer b.RUnlock() - - cpy := &Branch{ - key: make([]byte, len(b.key)), - children: b.children, // copy interface pointers - value: nil, - dirty: b.dirty, - hash: make([]byte, len(b.hash)), - encoding: make([]byte, len(b.encoding)), - generation: b.generation, - } - copy(cpy.key, b.key) - - // nil and []byte{} are encoded differently, watch out! - if b.value != nil { - cpy.value = make([]byte, len(b.value)) - copy(cpy.value, b.value) - } - - copy(cpy.hash, b.hash) - copy(cpy.encoding, b.encoding) - return cpy -} - -// Copy deep copies the leaf. -func (l *Leaf) Copy() Node { - l.RLock() - defer l.RUnlock() - - l.encodingMu.RLock() - defer l.encodingMu.RUnlock() - - cpy := &Leaf{ - key: make([]byte, len(l.key)), - value: make([]byte, len(l.value)), - dirty: l.dirty, - hash: make([]byte, len(l.hash)), - encoding: make([]byte, len(l.encoding)), - generation: l.generation, - } - copy(cpy.key, l.key) - copy(cpy.value, l.value) - copy(cpy.hash, l.hash) - copy(cpy.encoding, l.encoding) - return cpy -} - -// SetEncodingAndHash sets the encoding and hash slices -// given to the branch. Note it does not copy them, so beware. 
-func (b *Branch) SetEncodingAndHash(enc, hash []byte) { - b.encoding = enc - b.hash = hash -} - -// SetEncodingAndHash sets the encoding and hash slices -// given to the branch. Note it does not copy them, so beware. -func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { - l.encodingMu.Lock() - l.encoding = enc - l.encodingMu.Unlock() - - l.hash = hash -} - -// GetHash returns the hash of the branch. -// Note it does not copy it, so modifying -// the returned hash will modify the hash -// of the branch. -func (b *Branch) GetHash() []byte { - return b.hash -} - -// GetGeneration returns the generation of the branch. -func (b *Branch) GetGeneration() uint64 { - return b.generation -} - -// GetGeneration returns the generation of the leaf. -func (l *Leaf) GetGeneration() uint64 { - return l.generation -} - -// GetHash returns the hash of the leaf. -// Note it does not copy it, so modifying -// the returned hash will modify the hash -// of the branch. -func (l *Leaf) GetHash() []byte { - return l.hash -} - -func (b *Branch) String() string { - if len(b.value) > 1024 { - return fmt.Sprintf( - "branch key=%x childrenBitmap=%16b value (hashed)=%x dirty=%v", - b.key, b.childrenBitmap(), common.MustBlake2bHash(b.value), b.dirty) - } - return fmt.Sprintf("branch key=%x childrenBitmap=%16b value=%v dirty=%v", b.key, b.childrenBitmap(), b.value, b.dirty) -} - -func (l *Leaf) String() string { - if len(l.value) > 1024 { - return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.key, common.MustBlake2bHash(l.value), l.dirty) - } - return fmt.Sprintf("leaf key=%x value=%v dirty=%v", l.key, l.value, l.dirty) -} - -func (b *Branch) childrenBitmap() uint16 { - var bitmap uint16 - var i uint - for i = 0; i < 16; i++ { - if b.children[i] != nil { - bitmap = bitmap | 1<> 6 - if nodeType == 1 { - l := new(Leaf) - err := l.Decode(r, header) - return l, err - } else if nodeType == 2 || nodeType == 3 { - b := new(Branch) - err := b.Decode(r, header) - return b, err - } - - return nil, errors.New("cannot decode invalid encoding into node") -} - -// Decode decodes a byte array with the encoding specified at the top of this package into a branch node -// Note that since the encoded branch stores the hash of the children nodes, we aren't able to reconstruct the child -// nodes from the encoding. This function instead stubs where the children are known to be with an empty leaf. 
-func (b *Branch) Decode(r io.Reader, header byte) (err error) { - if header == 0 { - header, err = readByte(r) - if err != nil { - return err - } - } - - nodeType := header >> 6 - if nodeType != 2 && nodeType != 3 { - return fmt.Errorf("cannot decode node to branch") - } - - keyLen := header & 0x3f - b.key, err = decodeKey(r, keyLen) - if err != nil { - return err - } - - childrenBitmap := make([]byte, 2) - _, err = r.Read(childrenBitmap) - if err != nil { - return err - } - - sd := scale.NewDecoder(r) - - if nodeType == 3 { - var value []byte - // branch w/ value - err := sd.Decode(&value) - if err != nil { - return err - } - b.value = value - } - - for i := 0; i < 16; i++ { - if (childrenBitmap[i/8]>>(i%8))&1 == 1 { - var hash []byte - err := sd.Decode(&hash) - if err != nil { - return err - } - - b.children[i] = &Leaf{ - hash: hash, - } - } - } - - b.dirty = true - - return nil -} - -// Decode decodes a byte array with the encoding specified at the top of this package into a leaf node -func (l *Leaf) Decode(r io.Reader, header byte) (err error) { - if header == 0 { - header, err = readByte(r) - if err != nil { - return err - } - } - - nodeType := header >> 6 - if nodeType != 1 { - return fmt.Errorf("cannot decode node to leaf") - } - - keyLen := header & 0x3f - l.key, err = decodeKey(r, keyLen) - if err != nil { - return err - } - - sd := scale.NewDecoder(r) - var value []byte - err = sd.Decode(&value) - if err != nil { - return err - } - - if len(value) > 0 { - l.value = value - } - - l.dirty = true - - return nil -} - -func (b *Branch) header() ([]byte, error) { - var header byte - if b.value == nil { - header = 2 << 6 - } else { - header = 3 << 6 - } - var encodePkLen []byte - var err error - - if len(b.key) >= 63 { - header = header | 0x3f - encodePkLen, err = encodeExtraPartialKeyLength(len(b.key)) - if err != nil { - return nil, err - } - } else { - header = header | byte(len(b.key)) - } - - fullHeader := append([]byte{header}, encodePkLen...) - return fullHeader, nil -} - -func (l *Leaf) header() ([]byte, error) { - var header byte = 1 << 6 - var encodePkLen []byte - var err error - - if len(l.key) >= 63 { - header = header | 0x3f - encodePkLen, err = encodeExtraPartialKeyLength(len(l.key)) - if err != nil { - return nil, err - } - } else { - header = header | byte(len(l.key)) - } - - fullHeader := append([]byte{header}, encodePkLen...) 
- return fullHeader, nil -} - -var ErrPartialKeyTooBig = errors.New("partial key length greater than or equal to 2^16") - -func encodeExtraPartialKeyLength(pkLen int) ([]byte, error) { - pkLen -= 63 - fullHeader := []byte{} - - if pkLen >= 1<<16 { - return nil, ErrPartialKeyTooBig - } - - for i := 0; i < 1<<16; i++ { - if pkLen < 255 { - fullHeader = append(fullHeader, byte(pkLen)) - break - } else { - fullHeader = append(fullHeader, byte(255)) - pkLen -= 255 - } - } - - return fullHeader, nil -} - -func decodeKey(r io.Reader, keyLen byte) ([]byte, error) { - var totalKeyLen = int(keyLen) - - if keyLen == 0x3f { - // partial key longer than 63, read next bytes for rest of pk len - for { - nextKeyLen, err := readByte(r) - if err != nil { - return nil, err - } - totalKeyLen += int(nextKeyLen) - - if nextKeyLen < 0xff { - break - } - - if totalKeyLen >= 1<<16 { - return nil, errors.New("partial key length greater than or equal to 2^16") - } - } - } - - if totalKeyLen != 0 { - key := make([]byte, totalKeyLen/2+totalKeyLen%2) - _, err := r.Read(key) - if err != nil { - return key, err - } - - return keyToNibbles(key)[totalKeyLen%2:], nil - } - - return []byte{}, nil -} - -func readByte(r io.Reader) (byte, error) { - buf := make([]byte, 1) - _, err := r.Read(buf) - if err != nil { - return 0, err - } - return buf[0], nil -} diff --git a/lib/trie/node/interface.go b/lib/trie/node/interface.go new file mode 100644 index 0000000000..2c1b26ee15 --- /dev/null +++ b/lib/trie/node/interface.go @@ -0,0 +1,25 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "github.com/ChainSafe/gossamer/lib/trie/encode" +) + +// Node is node in the trie and can be a leaf or a branch. +type Node interface { + Encode(buffer encode.Buffer) (err error) // TODO change to io.Writer + EncodeAndHash() ([]byte, []byte, error) + ScaleEncodeHash() (b []byte, err error) + // Decode(r io.Reader, h byte) error + IsDirty() bool + SetDirty(dirty bool) + SetKey(key []byte) + String() string + SetEncodingAndHash([]byte, []byte) + GetHash() []byte + GetGeneration() uint64 + SetGeneration(uint64) + Copy() Node +} diff --git a/lib/trie/node/types.go b/lib/trie/node/types.go new file mode 100644 index 0000000000..a912955b3e --- /dev/null +++ b/lib/trie/node/types.go @@ -0,0 +1,17 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// Type is the byte type for the node. +type Type byte + +const ( + _ = iota + // LeafType type is 1 + LeafType + // BranchType type is 2 + BranchType + // BranchWithValueType type is 3 + BranchWithValueType +) diff --git a/lib/trie/node_mock_test.go b/lib/trie/node_mock_test.go deleted file mode 100644 index d381d1a157..0000000000 --- a/lib/trie/node_mock_test.go +++ /dev/null @@ -1,183 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: node.go - -// Package trie is a generated GoMock package. -package trie - -import ( - io "io" - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockNode is a mock of Node interface. -type MockNode struct { - ctrl *gomock.Controller - recorder *MockNodeMockRecorder -} - -// MockNodeMockRecorder is the mock recorder for MockNode. -type MockNodeMockRecorder struct { - mock *MockNode -} - -// NewMockNode creates a new mock instance. 
-func NewMockNode(ctrl *gomock.Controller) *MockNode { - mock := &MockNode{ctrl: ctrl} - mock.recorder = &MockNodeMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockNode) EXPECT() *MockNodeMockRecorder { - return m.recorder -} - -// Copy mocks base method. -func (m *MockNode) Copy() Node { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Copy") - ret0, _ := ret[0].(Node) - return ret0 -} - -// Copy indicates an expected call of Copy. -func (mr *MockNodeMockRecorder) Copy() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Copy", reflect.TypeOf((*MockNode)(nil).Copy)) -} - -// Decode mocks base method. -func (m *MockNode) Decode(r io.Reader, h byte) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Decode", r, h) - ret0, _ := ret[0].(error) - return ret0 -} - -// Decode indicates an expected call of Decode. -func (mr *MockNodeMockRecorder) Decode(r, h interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decode", reflect.TypeOf((*MockNode)(nil).Decode), r, h) -} - -// EncodeAndHash mocks base method. -func (m *MockNode) EncodeAndHash() ([]byte, []byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "EncodeAndHash") - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].([]byte) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// EncodeAndHash indicates an expected call of EncodeAndHash. -func (mr *MockNodeMockRecorder) EncodeAndHash() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncodeAndHash", reflect.TypeOf((*MockNode)(nil).EncodeAndHash)) -} - -// GetGeneration mocks base method. -func (m *MockNode) GetGeneration() uint64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGeneration") - ret0, _ := ret[0].(uint64) - return ret0 -} - -// GetGeneration indicates an expected call of GetGeneration. -func (mr *MockNodeMockRecorder) GetGeneration() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGeneration", reflect.TypeOf((*MockNode)(nil).GetGeneration)) -} - -// GetHash mocks base method. -func (m *MockNode) GetHash() []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHash") - ret0, _ := ret[0].([]byte) - return ret0 -} - -// GetHash indicates an expected call of GetHash. -func (mr *MockNodeMockRecorder) GetHash() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHash", reflect.TypeOf((*MockNode)(nil).GetHash)) -} - -// IsDirty mocks base method. -func (m *MockNode) IsDirty() bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsDirty") - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsDirty indicates an expected call of IsDirty. -func (mr *MockNodeMockRecorder) IsDirty() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDirty", reflect.TypeOf((*MockNode)(nil).IsDirty)) -} - -// SetDirty mocks base method. -func (m *MockNode) SetDirty(dirty bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetDirty", dirty) -} - -// SetDirty indicates an expected call of SetDirty. -func (mr *MockNodeMockRecorder) SetDirty(dirty interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDirty", reflect.TypeOf((*MockNode)(nil).SetDirty), dirty) -} - -// SetEncodingAndHash mocks base method. 
-func (m *MockNode) SetEncodingAndHash(arg0, arg1 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetEncodingAndHash", arg0, arg1) -} - -// SetEncodingAndHash indicates an expected call of SetEncodingAndHash. -func (mr *MockNodeMockRecorder) SetEncodingAndHash(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetEncodingAndHash", reflect.TypeOf((*MockNode)(nil).SetEncodingAndHash), arg0, arg1) -} - -// SetGeneration mocks base method. -func (m *MockNode) SetGeneration(arg0 uint64) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetGeneration", arg0) -} - -// SetGeneration indicates an expected call of SetGeneration. -func (mr *MockNodeMockRecorder) SetGeneration(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGeneration", reflect.TypeOf((*MockNode)(nil).SetGeneration), arg0) -} - -// SetKey mocks base method. -func (m *MockNode) SetKey(key []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetKey", key) -} - -// SetKey indicates an expected call of SetKey. -func (mr *MockNodeMockRecorder) SetKey(key interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetKey", reflect.TypeOf((*MockNode)(nil).SetKey), key) -} - -// String mocks base method. -func (m *MockNode) String() string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "String") - ret0, _ := ret[0].(string) - return ret0 -} - -// String indicates an expected call of String. -func (mr *MockNodeMockRecorder) String() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockNode)(nil).String)) -} diff --git a/lib/trie/node_test.go b/lib/trie/node_test.go index 52b087c541..f667bb82b4 100644 --- a/lib/trie/node_test.go +++ b/lib/trie/node_test.go @@ -5,213 +5,11 @@ package trie import ( "bytes" - "math/rand" - "strconv" "testing" - "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" - - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// byteArray makes byte array with length specified; used to test byte array encoding -func byteArray(length int) []byte { - b := make([]byte, length) - for i := 0; i < length; i++ { - b[i] = 0xf - } - return b -} - -func generateRand(size int) [][]byte { - rt := make([][]byte, size) - for i := range rt { - buf := make([]byte, rand.Intn(379)+1) - rand.Read(buf) - rt[i] = buf - } - return rt -} - -func TestChildrenBitmap(t *testing.T) { - b := &Branch{children: [16]Node{}} - res := b.childrenBitmap() - if res != 0 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) - } - - b.children[0] = &Leaf{key: []byte{0x00}, value: []byte{0x00}} - res = b.childrenBitmap() - if res != 1 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 1) - } - - b.children[4] = &Leaf{key: []byte{0x00}, value: []byte{0x00}} - res = b.childrenBitmap() - if res != 1<<4+1 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 17) - } - - b.children[15] = &Leaf{key: []byte{0x00}, value: []byte{0x00}} - res = b.childrenBitmap() - if res != 1<<15+1<<4+1 { - t.Errorf("Fail to get children bitmap: got %x expected %x", res, 257) - } -} - -func TestBranchHeader(t *testing.T) { - tests := []struct { - br *Branch - header []byte - }{ - {&Branch{key: nil, children: [16]Node{}, value: nil}, []byte{0x80}}, - {&Branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, []byte{0x81}}, - 
{&Branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, []byte{0x84}}, - - {&Branch{key: nil, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc0}}, - {&Branch{key: []byte{0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc1}}, - {&Branch{key: []byte{0x00, 0x00}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc2}}, - {&Branch{key: []byte{0x00, 0x00, 0xf}, children: [16]Node{}, value: []byte{0x01}}, []byte{0xc3}}, - - {&Branch{key: byteArray(62), children: [16]Node{}, value: nil}, []byte{0xbe}}, - {&Branch{key: byteArray(62), children: [16]Node{}, value: []byte{0x00}}, []byte{0xfe}}, - {&Branch{key: byteArray(63), children: [16]Node{}, value: nil}, []byte{0xbf, 0}}, - {&Branch{key: byteArray(64), children: [16]Node{}, value: nil}, []byte{0xbf, 1}}, - {&Branch{key: byteArray(64), children: [16]Node{}, value: []byte{0x01}}, []byte{0xff, 1}}, - - {&Branch{key: byteArray(317), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 254}}, - {&Branch{key: byteArray(318), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 255, 0}}, - {&Branch{key: byteArray(573), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 255, 255, 0}}, - } - - for _, test := range tests { - test := test - res, err := test.br.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } else if !bytes.Equal(res, test.header) { - t.Errorf("Branch header fail case %v: got %x expected %x", test.br, res, test.header) - } - } -} - -func TestFailingPk(t *testing.T) { - tests := []struct { - br *Branch - header []byte - }{ - {&Branch{key: byteArray(2 << 16), children: [16]Node{}, value: []byte{0x01}}, []byte{255, 254}}, - } - - for _, test := range tests { - _, err := test.br.header() - if err == nil { - t.Fatalf("should error when encoding node w pk length > 2^16") - } - } -} - -func TestLeafHeader(t *testing.T) { - tests := []struct { - br *Leaf - header []byte - }{ - {&Leaf{key: nil, value: nil}, []byte{0x40}}, - {&Leaf{key: []byte{0x00}, value: nil}, []byte{0x41}}, - {&Leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, []byte{0x44}}, - {&Leaf{key: byteArray(62), value: nil}, []byte{0x7e}}, - {&Leaf{key: byteArray(63), value: nil}, []byte{0x7f, 0}}, - {&Leaf{key: byteArray(64), value: []byte{0x01}}, []byte{0x7f, 1}}, - - {&Leaf{key: byteArray(318), value: []byte{0x01}}, []byte{0x7f, 0xff, 0}}, - {&Leaf{key: byteArray(573), value: []byte{0x01}}, []byte{0x7f, 0xff, 0xff, 0}}, - } - - for i, test := range tests { - test := test - t.Run(strconv.Itoa(i), func(t *testing.T) { - res, err := test.br.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } else if !bytes.Equal(res, test.header) { - t.Errorf("Leaf header fail: got %x expected %x", res, test.header) - } - }) - } -} - -func TestBranchEncode(t *testing.T) { - randKeys := generateRand(101) - randVals := generateRand(101) - - for i, testKey := range randKeys { - b := &Branch{key: testKey, children: [16]Node{}, value: randVals[i]} - expected := bytes.NewBuffer(nil) - - header, err := b.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } - - expected.Write(header) - expected.Write(nibblesToKeyLE(b.key)) - expected.Write(common.Uint16ToBytes(b.childrenBitmap())) - - enc, err := scale.Marshal(b.value) - if err != nil { - t.Fatalf("Fail when encoding value with scale: %s", err) - } - - expected.Write(enc) - - for _, child := range b.children { - if child == nil { - continue - } - - err := hashNode(child, expected) - require.NoError(t, err) - } - - buffer 
:= bytes.NewBuffer(nil) - const parallel = false - err = encodeBranch(b, buffer, parallel) - require.NoError(t, err) - assert.Equal(t, expected.Bytes(), buffer.Bytes()) - } -} - -func TestLeafEncode(t *testing.T) { - randKeys := generateRand(100) - randVals := generateRand(100) - - for i, testKey := range randKeys { - l := &Leaf{key: testKey, value: randVals[i]} - expected := []byte{} - - header, err := l.header() - if err != nil { - t.Fatalf("Error when encoding header: %s", err) - } - expected = append(expected, header...) - expected = append(expected, nibblesToKeyLE(l.key)...) - - enc, err := scale.Marshal(l.value) - if err != nil { - t.Fatalf("Fail when encoding value with scale: %s", err) - } - - expected = append(expected, enc...) - - buffer := bytes.NewBuffer(nil) - err = encodeLeaf(l, buffer) - require.NoError(t, err) - assert.Equal(t, expected, buffer.Bytes()) - } -} - func TestEncodeRoot(t *testing.T) { trie := NewEmptyTrie() @@ -222,133 +20,12 @@ func TestEncodeRoot(t *testing.T) { val := trie.Get(test.key) if !bytes.Equal(val, test.value) { - t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, val) + t.Errorf("Fail to get Key %x with value %x: got %x", test.Key(), test.value, val) } buffer := bytes.NewBuffer(nil) - const parallel = false - err := encodeNode(trie.root, buffer, parallel) + err := trie.root.Encode(buffer) require.NoError(t, err) } } } - -func TestBranchDecode(t *testing.T) { - tests := []*Branch{ - {key: []byte{}, children: [16]Node{}, value: nil}, - {key: []byte{0x00}, children: [16]Node{}, value: nil}, - {key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, - {key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, - {key: []byte{}, children: [16]Node{&Leaf{}}, value: []byte{0x01}}, - {key: []byte{}, children: [16]Node{&Leaf{}, nil, &Leaf{}}, value: []byte{0x01}}, - { - key: []byte{}, - children: [16]Node{ - &Leaf{}, nil, &Leaf{}, nil, - nil, nil, nil, nil, - nil, &Leaf{}, nil, &Leaf{}, - }, - value: []byte{0x01}, - }, - {key: byteArray(62), children: [16]Node{}, value: nil}, - {key: byteArray(63), children: [16]Node{}, value: nil}, - {key: byteArray(64), children: [16]Node{}, value: nil}, - {key: byteArray(317), children: [16]Node{}, value: []byte{0x01}}, - {key: byteArray(318), children: [16]Node{}, value: []byte{0x01}}, - {key: byteArray(573), children: [16]Node{}, value: []byte{0x01}}, - } - - buffer := bytes.NewBuffer(nil) - const parallel = false - - for _, test := range tests { - err := encodeBranch(test, buffer, parallel) - require.NoError(t, err) - - res := new(Branch) - err = res.Decode(buffer, 0) - - require.NoError(t, err) - require.Equal(t, test.key, res.key) - require.Equal(t, test.childrenBitmap(), res.childrenBitmap()) - require.Equal(t, test.value, res.value) - } -} - -func TestLeafDecode(t *testing.T) { - tests := []*Leaf{ - {key: []byte{}, value: nil, dirty: true}, - {key: []byte{0x01}, value: nil, dirty: true}, - {key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil, dirty: true}, - {key: byteArray(62), value: nil, dirty: true}, - {key: byteArray(63), value: nil, dirty: true}, - {key: byteArray(64), value: []byte{0x01}, dirty: true}, - {key: byteArray(318), value: []byte{0x01}, dirty: true}, - {key: byteArray(573), value: []byte{0x01}, dirty: true}, - } - - buffer := bytes.NewBuffer(nil) - - for _, test := range tests { - err := encodeLeaf(test, buffer) - require.NoError(t, err) - - res := new(Leaf) - err = res.Decode(buffer, 0) - require.NoError(t, err) - - res.hash = nil - test.encoding = nil - 
require.Equal(t, test, res) - } -} - -func TestDecode(t *testing.T) { - tests := []Node{ - &Branch{key: []byte{}, children: [16]Node{}, value: nil}, - &Branch{key: []byte{0x00}, children: [16]Node{}, value: nil}, - &Branch{key: []byte{0x00, 0x00, 0xf, 0x3}, children: [16]Node{}, value: nil}, - &Branch{key: []byte{}, children: [16]Node{}, value: []byte{0x01}}, - &Branch{key: []byte{}, children: [16]Node{&Leaf{}}, value: []byte{0x01}}, - &Branch{key: []byte{}, children: [16]Node{&Leaf{}, nil, &Leaf{}}, value: []byte{0x01}}, - &Branch{ - key: []byte{}, - children: [16]Node{ - &Leaf{}, nil, &Leaf{}, nil, - nil, nil, nil, nil, - nil, &Leaf{}, nil, &Leaf{}}, - value: []byte{0x01}, - }, - &Leaf{key: []byte{}, value: nil}, - &Leaf{key: []byte{0x00}, value: nil}, - &Leaf{key: []byte{0x00, 0x00, 0xf, 0x3}, value: nil}, - &Leaf{key: byteArray(62), value: nil}, - &Leaf{key: byteArray(63), value: nil}, - &Leaf{key: byteArray(64), value: []byte{0x01}}, - &Leaf{key: byteArray(318), value: []byte{0x01}}, - &Leaf{key: byteArray(573), value: []byte{0x01}}, - } - - buffer := bytes.NewBuffer(nil) - const parallel = false - - for _, test := range tests { - err := encodeNode(test, buffer, parallel) - require.NoError(t, err) - - res, err := decode(buffer) - require.NoError(t, err) - - switch n := test.(type) { - case *Branch: - require.Equal(t, n.key, res.(*Branch).key) - require.Equal(t, n.childrenBitmap(), res.(*Branch).childrenBitmap()) - require.Equal(t, n.value, res.(*Branch).value) - case *Leaf: - require.Equal(t, n.key, res.(*Leaf).key) - require.Equal(t, n.value, res.(*Leaf).value) - default: - t.Fatal("unexpected node") - } - } -} diff --git a/lib/trie/pools/pools.go b/lib/trie/pools/pools.go new file mode 100644 index 0000000000..1bfe8f5a83 --- /dev/null +++ b/lib/trie/pools/pools.go @@ -0,0 +1,42 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package pools + +import ( + "bytes" + "sync" + + "golang.org/x/crypto/blake2b" +) + +// DigestBuffers is a sync pool of buffers of capacity 32. +var DigestBuffers = &sync.Pool{ + New: func() interface{} { + const bufferCapacity = 32 + b := make([]byte, 0, bufferCapacity) + return bytes.NewBuffer(b) + }, +} + +// EncodingBuffers is a sync pool of buffers of capacity 1.9MB. +var EncodingBuffers = &sync.Pool{ + New: func() interface{} { + const initialBufferCapacity = 1900000 // 1.9MB, from checking capacities at runtime + b := make([]byte, 0, initialBufferCapacity) + return bytes.NewBuffer(b) + }, +} + +// Hashers is a sync pool of blake2b 256 hashers. 
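+// Callers should Get a hash.Hash from the pool, Reset it before writing
+// and Put it back once done, as the leaf hashing code earlier in this
+// patch does.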
+var Hashers = &sync.Pool{ + New: func() interface{} { + hasher, err := blake2b.New256(nil) + if err != nil { + // Conversation on why we panic here: + // https://github.com/ChainSafe/gossamer/pull/2009#discussion_r753430764 + panic("cannot create Blake2b-256 hasher: " + err.Error()) + } + return hasher + }, +} diff --git a/lib/trie/print.go b/lib/trie/print.go index 1804f86d99..ce3f85a979 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -8,6 +8,10 @@ import ( "fmt" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/lib/trie/pools" "github.com/disiqueira/gotree" ) @@ -23,49 +27,48 @@ func (t *Trie) String() string { return fmt.Sprintf("\n%s", tree.Print()) } -func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { +func (t *Trie) string(tree gotree.Tree, curr node.Node, idx int) { switch c := curr.(type) { - case *Branch: - buffer := encodingBufferPool.Get().(*bytes.Buffer) + case *branch.Branch: + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() - const parallel = false - _ = encodeBranch(c, buffer, parallel) - c.encoding = buffer.Bytes() + _ = c.Encode(buffer) + c.Encoding = buffer.Bytes() var bstr string - if len(c.encoding) > 1024 { - bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) + if len(c.Encoding) > 1024 { + bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.Encoding), c.Generation) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.Generation) } - encodingBufferPool.Put(buffer) + pools.EncodingBuffers.Put(buffer) sub := tree.Add(bstr) - for i, child := range c.children { + for i, child := range c.Children { if child != nil { t.string(sub, child, i) } } - case *Leaf: - buffer := encodingBufferPool.Get().(*bytes.Buffer) + case *leaf.Leaf: + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() - _ = encodeLeaf(c, buffer) + _ = c.Encode(buffer) - c.encodingMu.Lock() - defer c.encodingMu.Unlock() - c.encoding = buffer.Bytes() + // TODO lock or use methods on leaf to set the encoding bytes. 
+ // Right now this is only used for debugging so no need to lock + c.Encoding = buffer.Bytes() var bstr string - if len(c.encoding) > 1024 { - bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.encoding), c.generation) + if len(c.Encoding) > 1024 { + bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.Encoding), c.Generation) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.encoding, c.generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.Generation) } - encodingBufferPool.Put(buffer) + pools.EncodingBuffers.Put(buffer) tree.Add(bstr) default: diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 2b77f846f5..12c1d382a0 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -11,6 +11,7 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/decode" ) var ( @@ -40,7 +41,7 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e } for _, k := range keys { - nk := keyToNibbles(k) + nk := decode.KeyLEToNibbles(k) recorder := new(recorder) err := findAndRecord(proofTrie, nk, recorder) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index ad99492c8b..5e7435db5e 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -8,6 +8,12 @@ import ( "fmt" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/lib/trie/pools" ) // EmptyHash is the empty trie hash. @@ -18,7 +24,7 @@ var EmptyHash, _ = NewEmptyTrie().Hash() // Use NewTrie to create a trie that sits on top of a database. type Trie struct { generation uint64 - root Node + root node.Node childTries map[common.Hash]*Trie // Used to store the child tries. deletedKeys []common.Hash parallel bool @@ -30,7 +36,7 @@ func NewEmptyTrie() *Trie { } // NewTrie creates a trie with an existing root node -func NewTrie(root Node) *Trie { +func NewTrie(root node.Node) *Trie { return &Trie{ root: root, childTries: make(map[common.Hash]*Trie), @@ -63,7 +69,7 @@ func (t *Trie) Snapshot() *Trie { return newTrie } -func (t *Trie) maybeUpdateGeneration(n Node) Node { +func (t *Trie) maybeUpdateGeneration(n node.Node) node.Node { if n == nil { return nil } @@ -102,13 +108,20 @@ func (t *Trie) DeepCopy() (*Trie, error) { } // RootNode returns the root of the trie -func (t *Trie) RootNode() Node { +func (t *Trie) RootNode() node.Node { return t.root } // encodeRoot returns the encoded root of the trie func (t *Trie) encodeRoot(buffer *bytes.Buffer) (err error) { - return encodeNode(t.RootNode(), buffer, t.parallel) + if t.root == nil { + _, err = buffer.Write([]byte{0}) + if err != nil { + return fmt.Errorf("cannot write nil root node to buffer: %w", err) + } + return nil + } + return t.root.Encode(buffer) } // MustHash returns the hashed root of the trie. It panics if it fails to hash the root node. 
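For reference, the nil-root branch added to encodeRoot above means an empty trie encodes to the single byte 0x00, so the package-level EmptyHash should simply be the Blake2b-256 digest of that byte (assuming Hash hashes the root encoding, as the buffer handling in the next hunk suggests). A minimal sketch, not part of this patch, reusing the common.MustBlake2bHash helper already present elsewhere in this diff:

	package main

	import (
		"fmt"

		"github.com/ChainSafe/gossamer/lib/common"
	)

	func main() {
		// encodeRoot writes a single zero byte when the trie root is nil.
		emptyRootEncoding := []byte{0}

		// Hashing that byte yields the empty trie root hash,
		// i.e. what NewEmptyTrie().Hash() caches in EmptyHash.
		rootHash := common.MustBlake2bHash(emptyRootEncoding)
		fmt.Printf("empty trie root hash: 0x%x\n", rootHash)
	}
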
@@ -123,9 +136,9 @@ func (t *Trie) MustHash() common.Hash { // Hash returns the hashed root of the trie func (t *Trie) Hash() (common.Hash, error) { - buffer := encodingBufferPool.Get().(*bytes.Buffer) + buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() - defer encodingBufferPool.Put(buffer) + defer pools.EncodingBuffers.Put(buffer) err := t.encodeRoot(buffer) if err != nil { @@ -140,17 +153,17 @@ func (t *Trie) Entries() map[string][]byte { return t.entries(t.root, nil, make(map[string][]byte)) } -func (t *Trie) entries(current Node, prefix []byte, kv map[string][]byte) map[string][]byte { +func (t *Trie) entries(current node.Node, prefix []byte, kv map[string][]byte) map[string][]byte { switch c := current.(type) { - case *Branch: - if c.value != nil { - kv[string(nibblesToKeyLE(append(prefix, c.key...)))] = c.value + case *branch.Branch: + if c.Value != nil { + kv[string(encode.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value } - for i, child := range c.children { - t.entries(child, append(prefix, append(c.key, byte(i))...), kv) + for i, child := range c.Children { + t.entries(child, append(prefix, append(c.Key, byte(i))...), kv) } - case *Leaf: - kv[string(nibblesToKeyLE(append(prefix, c.key...)))] = c.value + case *leaf.Leaf: + kv[string(encode.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value return kv } @@ -159,20 +172,20 @@ func (t *Trie) entries(current Node, prefix []byte, kv map[string][]byte) map[st // NextKey returns the next key in the trie in lexicographic order. It returns nil if there is no next key func (t *Trie) NextKey(key []byte) []byte { - k := keyToNibbles(key) + k := decode.KeyLEToNibbles(key) next := t.nextKey(t.root, nil, k) if next == nil { return nil } - return nibblesToKeyLE(next) + return encode.NibblesToKeyLE(next) } -func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { +func (t *Trie) nextKey(curr node.Node, prefix, key []byte) []byte { switch c := curr.(type) { - case *Branch: - fullKey := append(prefix, c.key...) + case *branch.Branch: + fullKey := append(prefix, c.Key...) var cmp int if len(key) < len(fullKey) { if bytes.Compare(key, fullKey[:len(key)]) == 1 { // arg key is greater than full, return nil @@ -190,11 +203,11 @@ func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { // return key of first child, or key of this branch, // if it's a branch with value. if (cmp == 0 && len(key) == len(fullKey)) || cmp == 1 { - if c.value != nil && bytes.Compare(fullKey, key) > 0 { + if c.Value != nil && bytes.Compare(fullKey, key) > 0 { return fullKey } - for i, child := range c.children { + for i, child := range c.Children { if child == nil { continue } @@ -209,7 +222,7 @@ func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { // node key isn't greater than the arg key, continue to iterate if cmp < 1 && len(key) > len(fullKey) { idx := key[len(fullKey)] - for i, child := range c.children[idx:] { + for i, child := range c.Children[idx:] { if child == nil { continue } @@ -220,8 +233,8 @@ func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { } } } - case *Leaf: - fullKey := append(prefix, c.key...) + case *leaf.Leaf: + fullKey := append(prefix, c.Key...) var cmp int if len(key) < len(fullKey) { if bytes.Compare(key, fullKey[:len(key)]) == 1 { // arg key is greater than full, return nil @@ -236,7 +249,7 @@ func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { } if cmp == 1 { - return append(prefix, c.key...) + return append(prefix, c.Key...) 
} case nil: return nil @@ -250,15 +263,15 @@ func (t *Trie) Put(key, value []byte) { } func (t *Trie) tryPut(key, value []byte) { - k := keyToNibbles(key) + k := decode.KeyLEToNibbles(key) - t.root = t.insert(t.root, k, &Leaf{key: nil, value: value, dirty: true, generation: t.generation}) + t.root = t.insert(t.root, k, &leaf.Leaf{Key: nil, Value: value, Dirty: true, Generation: t.generation}) } // insert attempts to insert a key with value into the trie -func (t *Trie) insert(parent Node, key []byte, value Node) Node { +func (t *Trie) insert(parent node.Node, key []byte, value node.Node) node.Node { switch p := t.maybeUpdateGeneration(parent).(type) { - case *Branch: + case *branch.Branch: n := t.updateBranch(p, key, value) if p != nil && n != nil && n.IsDirty() { @@ -268,32 +281,32 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { case nil: value.SetKey(key) return value - case *Leaf: + case *leaf.Leaf: // if a value already exists in the trie at this key, overwrite it with the new value // if the values are the same, don't mark node dirty - if p.value != nil && bytes.Equal(p.key, key) { - if !bytes.Equal(value.(*Leaf).value, p.value) { - p.value = value.(*Leaf).value - p.dirty = true + if p.Value != nil && bytes.Equal(p.Key, key) { + if !bytes.Equal(value.(*leaf.Leaf).Value, p.Value) { + p.Value = value.(*leaf.Leaf).Value + p.Dirty = true } return p } - length := lenCommonPrefix(key, p.key) + length := lenCommonPrefix(key, p.Key) // need to convert this leaf into a branch - br := &Branch{key: key[:length], dirty: true, generation: t.generation} - parentKey := p.key + br := &branch.Branch{Key: key[:length], Dirty: true, Generation: t.generation} + parentKey := p.Key // value goes at this branch if len(key) == length { - br.value = value.(*Leaf).value + br.Value = value.(*leaf.Leaf).Value br.SetDirty(true) // if we are not replacing previous leaf, then add it as a child to the new branch if len(parentKey) > len(key) { - p.key = p.key[length+1:] - br.children[parentKey[length]] = p + p.Key = p.Key[length+1:] + br.Children[parentKey[length]] = p p.SetDirty(true) } @@ -302,17 +315,17 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { value.SetKey(key[length+1:]) - if length == len(p.key) { + if length == len(p.Key) { // if leaf's key is covered by this branch, then make the leaf's // value the value at this branch - br.value = p.value - br.children[key[length]] = value + br.Value = p.Value + br.Children[key[length]] = value } else { // otherwise, make the leaf a child of the branch and update its partial key - p.key = p.key[length+1:] + p.Key = p.Key[length+1:] p.SetDirty(true) - br.children[parentKey[length]] = p - br.children[key[length]] = value + br.Children[parentKey[length]] = p + br.Children[key[length]] = value } return br @@ -324,34 +337,34 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { // updateBranch attempts to add the value node to a branch // inserts the value node as the branch's child at the index that's // the first nibble of the key -func (t *Trie) updateBranch(p *Branch, key []byte, value Node) (n Node) { - length := lenCommonPrefix(key, p.key) +func (t *Trie) updateBranch(p *branch.Branch, key []byte, value node.Node) (n node.Node) { + length := lenCommonPrefix(key, p.Key) // whole parent key matches - if length == len(p.key) { + if length == len(p.Key) { // if node has same key as this branch, then update the value at this branch - if bytes.Equal(key, p.key) { + if bytes.Equal(key, p.Key) { p.SetDirty(true) switch 
v := value.(type) { - case *Branch: - p.value = v.value - case *Leaf: - p.value = v.value + case *branch.Branch: + p.Value = v.Value + case *leaf.Leaf: + p.Value = v.Value } return p } - switch c := p.children[key[length]].(type) { - case *Branch, *Leaf: + switch c := p.Children[key[length]].(type) { + case *branch.Branch, *leaf.Leaf: n = t.insert(c, key[length+1:], value) - p.children[key[length]] = n + p.Children[key[length]] = n n.SetDirty(true) p.SetDirty(true) return p case nil: // otherwise, add node as child of this branch - value.(*Leaf).key = key[length+1:] - p.children[key[length]] = value + value.(*leaf.Leaf).Key = key[length+1:] + p.Children[key[length]] = value p.SetDirty(true) return p } @@ -361,15 +374,15 @@ func (t *Trie) updateBranch(p *Branch, key []byte, value Node) (n Node) { // we need to branch out at the point where the keys diverge // update partial keys, new branch has key up to matching length - br := &Branch{key: key[:length], dirty: true, generation: t.generation} + br := &branch.Branch{Key: key[:length], Dirty: true, Generation: t.generation} - parentIndex := p.key[length] - br.children[parentIndex] = t.insert(nil, p.key[length+1:], p) + parentIndex := p.Key[length] + br.Children[parentIndex] = t.insert(nil, p.Key[length+1:], p) if len(key) <= length { - br.value = value.(*Leaf).value + br.Value = value.(*leaf.Leaf).Value } else { - br.children[key[length]] = t.insert(nil, key[length+1:], value) + br.Children[key[length]] = t.insert(nil, key[length+1:], value) } br.SetDirty(true) @@ -397,7 +410,7 @@ func (t *Trie) LoadFromMap(data map[string]string) error { func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { var p []byte if len(prefix) != 0 { - p = keyToNibbles(prefix) + p = decode.KeyLEToNibbles(prefix) if p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -406,28 +419,28 @@ func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { return t.getKeysWithPrefix(t.root, []byte{}, p, [][]byte{}) } -func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) [][]byte { +func (t *Trie) getKeysWithPrefix(parent node.Node, prefix, key []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *Branch: - length := lenCommonPrefix(p.key, key) + case *branch.Branch: + length := lenCommonPrefix(p.Key, key) - if bytes.Equal(p.key[:length], key) || len(key) == 0 { + if bytes.Equal(p.Key[:length], key) || len(key) == 0 { // node has prefix, add to list and add all descendant nodes to list keys = t.addAllKeys(p, prefix, keys) return keys } - if len(key) <= len(p.key) || length < len(p.key) { + if len(key) <= len(p.Key) || length < len(p.Key) { // no prefixed keys to be found here, return return keys } - key = key[len(p.key):] - keys = t.getKeysWithPrefix(p.children[key[0]], append(append(prefix, p.key...), key[0]), key[1:], keys) - case *Leaf: - length := lenCommonPrefix(p.key, key) - if bytes.Equal(p.key[:length], key) || len(key) == 0 { - keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) + key = key[len(p.Key):] + keys = t.getKeysWithPrefix(p.Children[key[0]], append(append(prefix, p.Key...), key[0]), key[1:], keys) + case *leaf.Leaf: + length := lenCommonPrefix(p.Key, key) + if bytes.Equal(p.Key[:length], key) || len(key) == 0 { + keys = append(keys, encode.NibblesToKeyLE(append(prefix, p.Key...))) } case nil: return keys @@ -437,18 +450,18 @@ func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) // addAllKeys appends all keys that are descendants of the parent node to a slice of keys // it uses the prefix 
to determine the entire key -func (t *Trie) addAllKeys(parent Node, prefix []byte, keys [][]byte) [][]byte { +func (t *Trie) addAllKeys(parent node.Node, prefix []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *Branch: - if p.value != nil { - keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) + case *branch.Branch: + if p.Value != nil { + keys = append(keys, encode.NibblesToKeyLE(append(prefix, p.Key...))) } - for i, child := range p.children { - keys = t.addAllKeys(child, append(append(prefix, p.key...), byte(i)), keys) + for i, child := range p.Children { + keys = t.addAllKeys(child, append(append(prefix, p.Key...), byte(i)), keys) } - case *Leaf: - keys = append(keys, nibblesToKeyLE(append(prefix, p.key...))) + case *leaf.Leaf: + keys = append(keys, encode.NibblesToKeyLE(append(prefix, p.Key...))) case nil: return keys } @@ -463,36 +476,36 @@ func (t *Trie) Get(key []byte) []byte { return nil } - return l.value + return l.Value } -func (t *Trie) tryGet(key []byte) *Leaf { - k := keyToNibbles(key) +func (t *Trie) tryGet(key []byte) *leaf.Leaf { + k := decode.KeyLEToNibbles(key) return t.retrieve(t.root, k) } -func (t *Trie) retrieve(parent Node, key []byte) *Leaf { +func (t *Trie) retrieve(parent node.Node, key []byte) *leaf.Leaf { var ( - value *Leaf + value *leaf.Leaf ) switch p := parent.(type) { - case *Branch: - length := lenCommonPrefix(p.key, key) + case *branch.Branch: + length := lenCommonPrefix(p.Key, key) // found the value at this node - if bytes.Equal(p.key, key) || len(key) == 0 { - return &Leaf{key: p.key, value: p.value, dirty: false} + if bytes.Equal(p.Key, key) || len(key) == 0 { + return &leaf.Leaf{Key: p.Key, Value: p.Value, Dirty: false} } // did not find value - if bytes.Equal(p.key[:length], key) && len(key) < len(p.key) { + if bytes.Equal(p.Key[:length], key) && len(key) < len(p.Key) { return nil } - value = t.retrieve(p.children[key[length]], key[length+1:]) - case *Leaf: - if bytes.Equal(p.key, key) { + value = t.retrieve(p.Children[key[length]], key[length+1:]) + case *leaf.Leaf: + if bytes.Equal(p.Key, key) { value = p } case nil: @@ -507,7 +520,7 @@ func (t *Trie) ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { return 0, false } - p := keyToNibbles(prefix) + p := decode.KeyLEToNibbles(prefix) if len(p) > 0 && p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -520,12 +533,12 @@ func (t *Trie) ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { // clearPrefixLimit deletes the keys having the prefix till limit reached and returns updated trie root node, // true if any node in the trie got updated, and next bool returns true if there is no keys left with prefix. 
-func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bool, bool) { +func (t *Trie) clearPrefixLimit(cn node.Node, prefix []byte, limit *uint32) (node.Node, bool, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *Branch: - length := lenCommonPrefix(c.key, prefix) + case *branch.Branch: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { n, _ := t.deleteNodes(c, []byte{}, limit) if n == nil { @@ -534,36 +547,36 @@ func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bo return n, true, false } - if len(prefix) == len(c.key)+1 && length == len(prefix)-1 { - i := prefix[len(c.key)] - c.children[i], _ = t.deleteNodes(c.children[i], []byte{}, limit) + if len(prefix) == len(c.Key)+1 && length == len(prefix)-1 { + i := prefix[len(c.Key)] + c.Children[i], _ = t.deleteNodes(c.Children[i], []byte{}, limit) c.SetDirty(true) curr = handleDeletion(c, prefix) - if c.children[i] == nil { + if c.Children[i] == nil { return curr, true, true } return c, true, false } - if len(prefix) <= len(c.key) || length < len(c.key) { + if len(prefix) <= len(c.Key) || length < len(c.Key) { // this node doesn't have the prefix, return return c, false, true } - i := prefix[len(c.key)] + i := prefix[len(c.Key)] var wasUpdated, allDeleted bool - c.children[i], wasUpdated, allDeleted = t.clearPrefixLimit(c.children[i], prefix[len(c.key)+1:], limit) + c.Children[i], wasUpdated, allDeleted = t.clearPrefixLimit(c.Children[i], prefix[len(c.Key)+1:], limit) if wasUpdated { c.SetDirty(true) curr = handleDeletion(c, prefix) } return curr, curr.IsDirty(), allDeleted - case *Leaf: - length := lenCommonPrefix(c.key, prefix) + case *leaf.Leaf: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { *limit-- return nil, true, true @@ -578,35 +591,35 @@ func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bo return nil, false, true } -func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (Node, bool) { +func (t *Trie) deleteNodes(cn node.Node, prefix []byte, limit *uint32) (node.Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *Leaf: + case *leaf.Leaf: if *limit == 0 { return c, false } *limit-- return nil, true - case *Branch: - if len(c.key) != 0 { - prefix = append(prefix, c.key...) + case *branch.Branch: + if len(c.Key) != 0 { + prefix = append(prefix, c.Key...) 
} - for i, child := range c.children { + for i, child := range c.Children { if child == nil { continue } var isDel bool - if c.children[i], isDel = t.deleteNodes(child, prefix, limit); !isDel { + if c.Children[i], isDel = t.deleteNodes(child, prefix, limit); !isDel { continue } c.SetDirty(true) curr = handleDeletion(c, prefix) - isAllNil := c.numChildren() == 0 - if isAllNil && c.value == nil { + isAllNil := c.NumChildren() == 0 + if isAllNil && c.Value == nil { curr = nil } @@ -620,7 +633,7 @@ func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (Node, bool) { } // Delete the current node as well - if c.value != nil { + if c.Value != nil { *limit-- } return nil, true @@ -636,7 +649,7 @@ func (t *Trie) ClearPrefix(prefix []byte) { return } - p := keyToNibbles(prefix) + p := decode.KeyLEToNibbles(prefix) if len(p) > 0 && p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -644,11 +657,11 @@ func (t *Trie) ClearPrefix(prefix []byte) { t.root, _ = t.clearPrefix(t.root, p) } -func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { +func (t *Trie) clearPrefix(cn node.Node, prefix []byte) (node.Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *Branch: - length := lenCommonPrefix(c.key, prefix) + case *branch.Branch: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { // found prefix at this branch, delete it @@ -657,32 +670,32 @@ func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { // Store the current node and return it, if the trie is not updated. - if len(prefix) == len(c.key)+1 && length == len(prefix)-1 { + if len(prefix) == len(c.Key)+1 && length == len(prefix)-1 { // found prefix at child index, delete child - i := prefix[len(c.key)] - c.children[i] = nil + i := prefix[len(c.Key)] + c.Children[i] = nil c.SetDirty(true) curr = handleDeletion(c, prefix) return curr, true } - if len(prefix) <= len(c.key) || length < len(c.key) { + if len(prefix) <= len(c.Key) || length < len(c.Key) { // this node doesn't have the prefix, return return c, false } var wasUpdated bool - i := prefix[len(c.key)] + i := prefix[len(c.Key)] - c.children[i], wasUpdated = t.clearPrefix(c.children[i], prefix[len(c.key)+1:]) + c.Children[i], wasUpdated = t.clearPrefix(c.Children[i], prefix[len(c.Key)+1:]) if wasUpdated { c.SetDirty(true) curr = handleDeletion(c, prefix) } return curr, curr.IsDirty() - case *Leaf: - length := lenCommonPrefix(c.key, prefix) + case *leaf.Leaf: + length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { return nil, true } @@ -696,35 +709,35 @@ func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { // Delete removes any existing value for key from the trie. func (t *Trie) Delete(key []byte) { - k := keyToNibbles(key) + k := decode.KeyLEToNibbles(key) t.root, _ = t.delete(t.root, k) } -func (t *Trie) delete(parent Node, key []byte) (Node, bool) { +func (t *Trie) delete(parent node.Node, key []byte) (node.Node, bool) { // Store the current node and return it, if the trie is not updated. 
switch p := t.maybeUpdateGeneration(parent).(type) { - case *Branch: + case *branch.Branch: - length := lenCommonPrefix(p.key, key) - if bytes.Equal(p.key, key) || len(key) == 0 { + length := lenCommonPrefix(p.Key, key) + if bytes.Equal(p.Key, key) || len(key) == 0 { // found the value at this node - p.value = nil + p.Value = nil p.SetDirty(true) return handleDeletion(p, key), true } - n, del := t.delete(p.children[key[length]], key[length+1:]) + n, del := t.delete(p.Children[key[length]], key[length+1:]) if !del { // If nothing was deleted then don't copy the path. return p, false } - p.children[key[length]] = n + p.Children[key[length]] = n p.SetDirty(true) n = handleDeletion(p, key) return n, true - case *Leaf: - if bytes.Equal(key, p.key) || len(key) == 0 { + case *leaf.Leaf: + if bytes.Equal(key, p.Key) || len(key) == 0 { // Key exists. Delete it. return nil, true } @@ -740,15 +753,15 @@ func (t *Trie) delete(parent Node, key []byte) (Node, bool) { // handleDeletion is called when a value is deleted from a branch // if the updated branch only has 1 child, it should be combined with that child // if the updated branch only has a value, it should be turned into a leaf -func handleDeletion(p *Branch, key []byte) Node { - var n Node = p - length := lenCommonPrefix(p.key, key) - bitmap := p.childrenBitmap() +func handleDeletion(p *branch.Branch, key []byte) node.Node { + var n node.Node = p + length := lenCommonPrefix(p.Key, key) + bitmap := p.ChildrenBitmap() // if branch has no children, just a value, turn it into a leaf - if bitmap == 0 && p.value != nil { - n = &Leaf{key: key[:length], value: p.value, dirty: true} - } else if p.numChildren() == 1 && p.value == nil { + if bitmap == 0 && p.Value != nil { + n = &leaf.Leaf{Key: key[:length], Value: p.Value, Dirty: true} + } else if p.NumChildren() == 1 && p.Value == nil { // there is only 1 child and no value, combine the child branch with this branch // find index of child var i int @@ -759,22 +772,22 @@ func handleDeletion(p *Branch, key []byte) Node { } } - child := p.children[i] + child := p.Children[i] switch c := child.(type) { - case *Leaf: - n = &Leaf{key: append(append(p.key, []byte{byte(i)}...), c.key...), value: c.value} - case *Branch: - br := new(Branch) - br.key = append(p.key, append([]byte{byte(i)}, c.key...)...) + case *leaf.Leaf: + n = &leaf.Leaf{Key: append(append(p.Key, []byte{byte(i)}...), c.Key...), Value: c.Value} + case *branch.Branch: + br := new(branch.Branch) + br.Key = append(p.Key, append([]byte{byte(i)}, c.Key...)...) 
// adopt the grandchildren - for i, grandchild := range c.children { + for i, grandchild := range c.Children { if grandchild != nil { - br.children[i] = grandchild + br.Children[i] = grandchild } } - br.value = c.value + br.Value = c.Value n = br default: // do nothing diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index 29cdc54e0a..150f272f63 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -21,6 +21,8 @@ import ( "github.com/stretchr/testify/require" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/leaf" ) type commonPrefixTest struct { @@ -68,7 +70,7 @@ func TestNewEmptyTrie(t *testing.T) { } func TestNewTrie(t *testing.T) { - trie := NewTrie(&Leaf{key: []byte{0}, value: []byte{17}}) + trie := NewTrie(&leaf.Leaf{Key: []byte{0}, Value: []byte{17}}) if trie == nil { t.Error("did not initialise trie") } @@ -160,10 +162,10 @@ func runTests(t *testing.T, trie *Trie, tests []Test) { leaf := trie.tryGet(test.key) if leaf == nil { t.Errorf("Fail to get key %x: nil leaf", test.key) - } else if !bytes.Equal(leaf.value, test.value) { - t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, leaf.value) - } else if !bytes.Equal(leaf.key, test.pk) { - t.Errorf("Fail to get correct partial key %x with key %x: got %x", test.pk, test.key, leaf.key) + } else if !bytes.Equal(leaf.Value, test.value) { + t.Errorf("Fail to get key %x with value %x: got %x", test.key, test.value, leaf.Value) + } else if !bytes.Equal(leaf.Key, test.pk) { + t.Errorf("Fail to get correct partial key %x with key %x: got %x", test.pk, test.key, leaf.Key) } } }) @@ -873,7 +875,7 @@ func TestClearPrefix(t *testing.T) { require.Equal(t, dcTrieHash, ssTrieHash) ssTrie.ClearPrefix(prefix) - prefixNibbles := keyToNibbles(prefix) + prefixNibbles := decode.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -881,7 +883,7 @@ func TestClearPrefix(t *testing.T) { for _, test := range tests { res := ssTrie.Get(test.key) - keyNibbles := keyToNibbles(test.key) + keyNibbles := decode.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { require.Nil(t, res) @@ -942,7 +944,11 @@ func TestClearPrefix_Small(t *testing.T) { } ssTrie.ClearPrefix([]byte("noo")) - require.Equal(t, ssTrie.root, &Leaf{key: keyToNibbles([]byte("other")), value: []byte("other"), dirty: true}) + require.Equal(t, ssTrie.root, &leaf.Leaf{ + Key: decode.KeyLEToNibbles([]byte("other")), + Value: []byte("other"), + Dirty: true, + }) // Get the updated root hash of all tries. 
tHash, err = trie.Hash() @@ -1310,7 +1316,7 @@ func TestTrie_ClearPrefixLimit(t *testing.T) { } testFn := func(testCase []Test, prefix []byte) { - prefixNibbles := keyToNibbles(prefix) + prefixNibbles := decode.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -1329,7 +1335,7 @@ func TestTrie_ClearPrefixLimit(t *testing.T) { for _, test := range testCase { val := trieClearPrefix.Get(test.key) - keyNibbles := keyToNibbles(test.key) + keyNibbles := decode.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { @@ -1418,7 +1424,7 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { for _, testCase := range cases { for _, prefix := range prefixes { - prefixNibbles := keyToNibbles(prefix) + prefixNibbles := decode.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -1458,7 +1464,7 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { for _, test := range testCase { val := ssTrie.Get(test.key) - keyNibbles := keyToNibbles(test.key) + keyNibbles := decode.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { From d87debaa91f3717a043a61175d454410a87212b0 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 29 Nov 2021 15:46:26 +0000 Subject: [PATCH 07/50] Address TODOs --- lib/trie/branch/children.go | 11 +++---- lib/trie/branch/decode.go | 9 +----- lib/trie/branch/decode_test.go | 36 +++++++---------------- lib/trie/branch/encode.go | 4 --- lib/trie/encodedecode_test/branch_test.go | 1 - lib/trie/leaf/decode.go | 11 ++----- lib/trie/leaf/decode_test.go | 28 +++++------------- lib/trie/leaf/encode.go | 13 -------- 8 files changed, 26 insertions(+), 87 deletions(-) diff --git a/lib/trie/branch/children.go b/lib/trie/branch/children.go index 54395458da..ff911dc513 100644 --- a/lib/trie/branch/children.go +++ b/lib/trie/branch/children.go @@ -5,10 +5,8 @@ package branch // ChildrenBitmap returns the 16 bit bitmap // of the children in the branch. -func (b *Branch) ChildrenBitmap() uint16 { - var bitmap uint16 - var i uint - for i = 0; i < 16; i++ { +func (b *Branch) ChildrenBitmap() (bitmap uint16) { + for i := uint(0); i < 16; i++ { if b.Children[i] != nil { bitmap = bitmap | 1<> 6 if nodeType != 2 && nodeType != 3 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) @@ -81,7 +74,7 @@ func Decode(reader io.Reader, header byte) (branch *Branch, err error) { } } - branch.Dirty = true // TODO move as soon as it gets modified? 
+ branch.Dirty = true return branch, nil } diff --git a/lib/trie/branch/decode_test.go b/lib/trie/branch/decode_test.go index 9d89ea2d7a..db2ce1f43f 100644 --- a/lib/trie/branch/decode_test.go +++ b/lib/trie/branch/decode_test.go @@ -44,46 +44,36 @@ func Test_Decode(t *testing.T) { errWrapped error errMessage string }{ - "no data with header 0": { - reader: bytes.NewBuffer(nil), - errWrapped: ErrReadHeaderByte, - errMessage: "cannot read header byte: EOF", - }, "no data with header 1": { reader: bytes.NewBuffer(nil), - header: 1, - errWrapped: ErrNodeTypeIsNotABranch, - errMessage: "node type is not a branch: 0", - }, - "first byte as 0 header 0": { - reader: bytes.NewBuffer([]byte{0}), + header: 65, errWrapped: ErrNodeTypeIsNotABranch, - errMessage: "node type is not a branch: 0", + errMessage: "node type is not a branch: 1", }, "key decoding error": { reader: bytes.NewBuffer([]byte{ - 129, // node type 2 and key length 1 // missing key data byte }), + header: 129, // node type 2 and key length 1 errWrapped: decode.ErrReadKeyData, errMessage: "cannot decode key: cannot read key data: EOF", }, "children bitmap read error": { reader: bytes.NewBuffer([]byte{ - 129, // node type 2 and key length 1 - 9, // key data + 9, // key data // missing children bitmap 2 bytes }), + header: 129, // node type 2 and key length 1 errWrapped: ErrReadChildrenBitmap, errMessage: "cannot read children bitmap: EOF", }, "children decoding error": { reader: bytes.NewBuffer([]byte{ - 129, // node type 2 and key length 1 9, // key data 0, 4, // children bitmap // missing children scale encoded data }), + header: 129, // node type 2 and key length 1 errWrapped: ErrDecodeChildHash, errMessage: "cannot decode child hash: at index 10: EOF", }, @@ -91,13 +81,13 @@ func Test_Decode(t *testing.T) { reader: bytes.NewBuffer( concatByteSlices([][]byte{ { - 129, // node type 2 and key length 1 9, // key data 0, 4, // children bitmap }, scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash }), ), + header: 129, // node type 2 and key length 1 branch: &Branch{ Key: []byte{9}, Children: [16]node.Node{ @@ -113,29 +103,25 @@ func Test_Decode(t *testing.T) { "value decoding error for node type 3": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ - { - 193, // node type 3 and key length 1 - 9, // key data - }, + {9}, // key data {0, 4}, // children bitmap // missing encoded branch value }), ), + header: 193, // node type 3 and key length 1 errWrapped: ErrDecodeValue, errMessage: "cannot decode value: EOF", }, "success node type 3": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ - { - 193, // node type 3 and key length 1 - 9, // key data - }, + {9}, // key data {0, 4}, // children bitmap scaleEncodeBytes(t, 7, 8, 9), // branch value scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash }), ), + header: 193, // node type 3 and key length 1 branch: &Branch{ Key: []byte{9}, Value: []byte{7, 8, 9}, diff --git a/lib/trie/branch/encode.go b/lib/trie/branch/encode.go index 3a1bfa4e74..04d7bc9cf9 100644 --- a/lib/trie/branch/encode.go +++ b/lib/trie/branch/encode.go @@ -21,10 +21,6 @@ import ( // and then SCALE encodes it. This is used to encode children // nodes of branches. 
func (b *Branch) ScaleEncodeHash() (encoding []byte, err error) { - // if b == nil { // TODO remove - // panic("Should write 0 to buffer") - // } - buffer := pools.DigestBuffers.Get().(*bytes.Buffer) buffer.Reset() defer pools.DigestBuffers.Put(buffer) diff --git a/lib/trie/encodedecode_test/branch_test.go b/lib/trie/encodedecode_test/branch_test.go index 41b5d68dcc..4cecf6893c 100644 --- a/lib/trie/encodedecode_test/branch_test.go +++ b/lib/trie/encodedecode_test/branch_test.go @@ -60,7 +60,6 @@ func Test_Branch_Encode_Decode(t *testing.T) { Key: []byte{5}, Children: [16]node.Node{ &leaf.Leaf{ - // TODO key and value are nil here?? Why? Hash: []byte{0x41, 0x9, 0x4, 0xa}, }, }, diff --git a/lib/trie/leaf/decode.go b/lib/trie/leaf/decode.go index a01afb4a1c..6c20fffbcc 100644 --- a/lib/trie/leaf/decode.go +++ b/lib/trie/leaf/decode.go @@ -19,14 +19,7 @@ var ( ) // Decode reads and decodes from a reader with the encoding specified in lib/trie/encode/doc.go. -func Decode(r io.Reader, header byte) (leaf *Leaf, err error) { // TODO return leaf - if header == 0 { // TODO remove this is taken care of by the caller - header, err = decode.ReadNextByte(r) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) - } - } - +func Decode(r io.Reader, header byte) (leaf *Leaf, err error) { nodeType := header >> 6 if nodeType != 1 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) @@ -51,7 +44,7 @@ func Decode(r io.Reader, header byte) (leaf *Leaf, err error) { // TODO return l leaf.Value = value } - leaf.Dirty = true // TODO move this as soon as it gets modified + leaf.Dirty = true return leaf, nil } diff --git a/lib/trie/leaf/decode_test.go b/lib/trie/leaf/decode_test.go index 5e98eb71d5..e702b666e2 100644 --- a/lib/trie/leaf/decode_test.go +++ b/lib/trie/leaf/decode_test.go @@ -42,45 +42,35 @@ func Test_Decode(t *testing.T) { errWrapped error errMessage string }{ - "no data with header 0": { - reader: bytes.NewBuffer(nil), - errWrapped: ErrReadHeaderByte, - errMessage: "cannot read header byte: EOF", - }, "no data with header 1": { reader: bytes.NewBuffer(nil), header: 1, errWrapped: ErrNodeTypeIsNotALeaf, errMessage: "node type is not a leaf: 0", }, - "first byte as 0 header 0": { - reader: bytes.NewBuffer([]byte{0}), - errWrapped: ErrNodeTypeIsNotALeaf, - errMessage: "node type is not a leaf: 0", - }, "key decoding error": { reader: bytes.NewBuffer([]byte{ - 65, // node type 1 and key length 1 // missing key data byte }), + header: 65, // node type 1 and key length 1 errWrapped: decode.ErrReadKeyData, errMessage: "cannot decode key: cannot read key data: EOF", }, "value decoding error": { reader: bytes.NewBuffer([]byte{ - 65, // node type 1 and key length 1 - 9, // key data + 9, // key data // missing value data }), + header: 65, // node type 1 and key length 1 errWrapped: ErrDecodeValue, errMessage: "cannot decode value: EOF", }, "zero value": { reader: bytes.NewBuffer([]byte{ - 65, // node type 1 and key length 1 - 9, // key data - 0, // missing value data + 9, // key data + 0, // missing value data }), + header: 65, // node type 1 and key length 1 leaf: &Leaf{ Key: []byte{9}, Dirty: true, @@ -89,13 +79,11 @@ func Test_Decode(t *testing.T) { "success": { reader: bytes.NewBuffer( concatByteSlices([][]byte{ - { - 65, // node type 1 and key length 1 - 9, // key data - }, + {9}, // key data scaleEncodeBytes(t, 1, 2, 3, 4, 5), // value data }), ), + header: 65, // node type 1 and key length 1 leaf: &Leaf{ Key: []byte{9}, Value: []byte{1, 2, 3, 4, 5}, diff --git 
a/lib/trie/leaf/encode.go b/lib/trie/leaf/encode.go index 8467d204f7..ca1dd62cc4 100644 --- a/lib/trie/leaf/encode.go +++ b/lib/trie/leaf/encode.go @@ -86,15 +86,6 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { // The encoding has the following format: // NodeHeader | Extra partial key length | Partial Key | Value func (l *Leaf) Encode(buffer encode.Buffer) (err error) { - // if l == nil { - // // TODO remove if not needed - // _, err := buffer.Write([]byte{0}) - // if err != nil { - // return fmt.Errorf("cannot write nil encoding to buffer: %w", err) - // } - // return nil - // } - l.encodingMu.RLock() if !l.Dirty && l.Encoding != nil { _, err = buffer.Write(l.Encoding) @@ -145,10 +136,6 @@ func (l *Leaf) Encode(buffer encode.Buffer) (err error) { // and then SCALE encodes it. This is used to encode children // nodes of branches. func (l *Leaf) ScaleEncodeHash() (b []byte, err error) { - // if l == nil { // TODO remove - // panic("Should write 0 to buffer") - // } - buffer := pools.DigestBuffers.Get().(*bytes.Buffer) buffer.Reset() defer pools.DigestBuffers.Put(buffer) From 77e4203c1f7c5974cb82d32e96d5cd701196de99 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 29 Nov 2021 16:10:10 +0000 Subject: [PATCH 08/50] Remove no longer needed mocks --- lib/trie/readwriter_mock_test.go | 64 -------------------------------- lib/trie/writer_mock_test.go | 49 ------------------------ 2 files changed, 113 deletions(-) delete mode 100644 lib/trie/readwriter_mock_test.go delete mode 100644 lib/trie/writer_mock_test.go diff --git a/lib/trie/readwriter_mock_test.go b/lib/trie/readwriter_mock_test.go deleted file mode 100644 index 6d1affa288..0000000000 --- a/lib/trie/readwriter_mock_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: io (interfaces: ReadWriter) - -// Package trie is a generated GoMock package. -package trie - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockReadWriter is a mock of ReadWriter interface. -type MockReadWriter struct { - ctrl *gomock.Controller - recorder *MockReadWriterMockRecorder -} - -// MockReadWriterMockRecorder is the mock recorder for MockReadWriter. -type MockReadWriterMockRecorder struct { - mock *MockReadWriter -} - -// NewMockReadWriter creates a new mock instance. -func NewMockReadWriter(ctrl *gomock.Controller) *MockReadWriter { - mock := &MockReadWriter{ctrl: ctrl} - mock.recorder = &MockReadWriterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockReadWriter) EXPECT() *MockReadWriterMockRecorder { - return m.recorder -} - -// Read mocks base method. -func (m *MockReadWriter) Read(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read. -func (mr *MockReadWriterMockRecorder) Read(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReadWriter)(nil).Read), arg0) -} - -// Write mocks base method. -func (m *MockReadWriter) Write(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. 
-func (mr *MockReadWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockReadWriter)(nil).Write), arg0) -} diff --git a/lib/trie/writer_mock_test.go b/lib/trie/writer_mock_test.go deleted file mode 100644 index b1009272f2..0000000000 --- a/lib/trie/writer_mock_test.go +++ /dev/null @@ -1,49 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: io (interfaces: Writer) - -// Package trie is a generated GoMock package. -package trie - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockWriter is a mock of Writer interface. -type MockWriter struct { - ctrl *gomock.Controller - recorder *MockWriterMockRecorder -} - -// MockWriterMockRecorder is the mock recorder for MockWriter. -type MockWriterMockRecorder struct { - mock *MockWriter -} - -// NewMockWriter creates a new mock instance. -func NewMockWriter(ctrl *gomock.Controller) *MockWriter { - mock := &MockWriter{ctrl: ctrl} - mock.recorder = &MockWriterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockWriter) EXPECT() *MockWriterMockRecorder { - return m.recorder -} - -// Write mocks base method. -func (m *MockWriter) Write(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), arg0) -} From d5f06a483e9a85da1b44b2cda8a9159f0fa1e2f9 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 29 Nov 2021 17:19:02 +0000 Subject: [PATCH 09/50] Fix encode decode tests --- lib/trie/encodedecode_test/branch_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/trie/encodedecode_test/branch_test.go b/lib/trie/encodedecode_test/branch_test.go index 4cecf6893c..97e4e7f8c1 100644 --- a/lib/trie/encodedecode_test/branch_test.go +++ b/lib/trie/encodedecode_test/branch_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/ChainSafe/gossamer/lib/trie/branch" + "github.com/ChainSafe/gossamer/lib/trie/decode" "github.com/ChainSafe/gossamer/lib/trie/leaf" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/stretchr/testify/assert" @@ -78,7 +79,9 @@ func Test_Branch_Encode_Decode(t *testing.T) { err := testCase.branchToEncode.Encode(buffer) require.NoError(t, err) - const header = 0 + header, err := decode.ReadNextByte(buffer) + require.NoError(t, err) + resultBranch, err := branch.Decode(buffer, header) require.NoError(t, err) From 07c52c231139a01a175494eea32ed3cab5517019 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 30 Nov 2021 09:14:28 +0000 Subject: [PATCH 10/50] Remove old parallel related code --- lib/trie/branch/encode_test.go | 66 ---------------------------------- lib/trie/trie.go | 4 --- lib/trie/trie_test.go | 31 +++------------- 3 files changed, 4 insertions(+), 97 deletions(-) diff --git a/lib/trie/branch/encode_test.go b/lib/trie/branch/encode_test.go index 4b9324f74b..7153329bc5 100644 --- a/lib/trie/branch/encode_test.go +++ b/lib/trie/branch/encode_test.go @@ -31,7 +31,6 @@ func Test_Branch_Encode(t *testing.T) { testCases := map[string]struct { branch *Branch writes []writeCall - parallel bool wrappedErr 
error errMessage string }{ @@ -179,71 +178,6 @@ func Test_Branch_Encode(t *testing.T) { "cannot encode child at index 3: " + "failed to write child to buffer: test error", }, - "buffer write error for children encoded in parallel": { - branch: &Branch{ - Key: []byte{1, 2, 3}, - Value: []byte{100}, - Children: [16]node.Node{ - nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, - nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - err: errTest, - }, - }, - parallel: true, - wrappedErr: errTest, - errMessage: "cannot encode children of branch: " + - "cannot encode child at index 3: " + - "failed to write child to buffer: " + - "test error", - }, - "success with parallel children encoding": { - branch: &Branch{ - Key: []byte{1, 2, 3}, - Value: []byte{100}, - Children: [16]node.Node{ - nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, - nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, - }, - }, - writes: []writeCall{ - { // header - written: []byte{195}, - }, - { // key LE - written: []byte{1, 35}, - }, - { // children bitmap - written: []byte{136, 0}, - }, - { // value - written: []byte{4, 100}, - }, - { // first children - written: []byte{12, 65, 9, 0}, - }, - { // second children - written: []byte{12, 65, 11, 0}, - }, - }, - parallel: true, - }, "success with sequential children encoding": { branch: &Branch{ Key: []byte{1, 2, 3}, diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 5e7435db5e..1995399974 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -27,7 +27,6 @@ type Trie struct { root node.Node childTries map[common.Hash]*Trie // Used to store the child tries. deletedKeys []common.Hash - parallel bool } // NewEmptyTrie creates a trie with a nil root @@ -42,7 +41,6 @@ func NewTrie(root node.Node) *Trie { childTries: make(map[common.Hash]*Trie), generation: 0, // Initially zero but increases after every snapshot. 
deletedKeys: make([]common.Hash, 0), - parallel: true, } } @@ -54,7 +52,6 @@ func (t *Trie) Snapshot() *Trie { generation: c.generation + 1, root: c.root, deletedKeys: make([]common.Hash, 0), - parallel: c.parallel, } } @@ -63,7 +60,6 @@ func (t *Trie) Snapshot() *Trie { root: t.root, childTries: children, deletedKeys: make([]common.Hash, 0), - parallel: t.parallel, } return newTrie diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index 150f272f63..e8c2f885bc 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -1131,35 +1131,12 @@ func Benchmark_Trie_Hash(b *testing.B) { trie.Put(test.key, test.value) } - trieTwo, err := trie.DeepCopy() - require.NoError(b, err) - - b.Run("Sequential hash", func(b *testing.B) { - trie.parallel = false - - b.StartTimer() - _, err := trie.Hash() - b.StopTimer() - - require.NoError(b, err) - - printMemUsage() - }) + b.StartTimer() + _, err := trie.Hash() + b.StopTimer() - b.Run("Parallel hash", func(b *testing.B) { - trieTwo.parallel = true - - b.StartTimer() - _, err := trieTwo.Hash() - b.StopTimer() - - require.NoError(b, err) - - printMemUsage() - }) -} + require.NoError(b, err) -func printMemUsage() { var m runtime.MemStats runtime.ReadMemStats(&m) // For info on each, see: https://golang.org/pkg/runtime/#MemStats From 4a7cb4c2b4d7213cdb04a5a1dfee4e6e2396506e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 30 Nov 2021 13:13:07 +0000 Subject: [PATCH 11/50] Fix eventual bad type assertion --- lib/trie/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index db337d0591..09c2a0c535 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -162,7 +162,7 @@ func (t *Trie) load(db chaindb.Database, curr node.Node) error { hash := child.GetHash() enc, err := db.Get(hash) if err != nil { - return fmt.Errorf("failed to find node key=%x index=%d: %w", child.(*leaf.Leaf).Hash, i, err) + return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err) } child, err = decodeNode(bytes.NewBuffer(enc)) From 0063885786e7146f9b277b0ea7c7f88204018c29 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 30 Nov 2021 13:13:51 +0000 Subject: [PATCH 12/50] Remove commented Decode method in node interface --- lib/trie/node/interface.go | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/trie/node/interface.go b/lib/trie/node/interface.go index 2c1b26ee15..1e52a39204 100644 --- a/lib/trie/node/interface.go +++ b/lib/trie/node/interface.go @@ -12,7 +12,6 @@ type Node interface { Encode(buffer encode.Buffer) (err error) // TODO change to io.Writer EncodeAndHash() ([]byte, []byte, error) ScaleEncodeHash() (b []byte, err error) - // Decode(r io.Reader, h byte) error IsDirty() bool SetDirty(dirty bool) SetKey(key []byte) From a3ee3a777a510fefd0afec6524dc19046a0d64ad Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 30 Nov 2021 20:53:13 +0100 Subject: [PATCH 13/50] chore(lib/trie): `lib/trie/recorder` sub-package (#2082) * `lib/trie/recorder` subpackage * return an error on a call to Next() with no node * remove recorder `IsEmpty` method * Recorder `GetNodes()` --- lib/trie/lookup.go | 13 +++- lib/trie/proof.go | 13 ++-- lib/trie/record/node.go | 7 ++ lib/trie/record/recorder.go | 27 ++++++++ lib/trie/record/recorder_test.go | 115 +++++++++++++++++++++++++++++++ lib/trie/recorder.go | 34 --------- 6 files changed, 164 insertions(+), 45 deletions(-) create mode 100644 lib/trie/record/node.go create mode 100644 lib/trie/record/recorder.go create mode 100644 
lib/trie/record/recorder_test.go delete mode 100644 lib/trie/recorder.go diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 4df79c6c33..e51596dc69 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -8,20 +8,27 @@ import ( "github.com/ChainSafe/gossamer/lib/trie/branch" "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/lib/trie/record" ) +var _ recorder = (*record.Recorder)(nil) + +type recorder interface { + Record(hash, rawData []byte) +} + // findAndRecord search for a desired key recording all the nodes in the path including the desired node -func findAndRecord(t *Trie, key []byte, recorder *recorder) error { +func findAndRecord(t *Trie, key []byte, recorder recorder) error { return find(t.root, key, recorder) } -func find(parent node.Node, key []byte, recorder *recorder) error { +func find(parent node.Node, key []byte, recorder recorder) error { enc, hash, err := parent.EncodeAndHash() if err != nil { return err } - recorder.record(hash, enc) + recorder.Record(hash, enc) b, ok := parent.(*branch.Branch) if !ok { diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 12c1d382a0..094fe48f7e 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -1,6 +1,3 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - package trie import ( @@ -12,6 +9,7 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/record" ) var ( @@ -43,17 +41,16 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e for _, k := range keys { nk := decode.KeyLEToNibbles(k) - recorder := new(recorder) + recorder := record.NewRecorder() err := findAndRecord(proofTrie, nk, recorder) if err != nil { return nil, err } - for !recorder.isEmpty() { - recNode := recorder.next() - nodeHashHex := common.BytesToHex(recNode.hash) + for _, recNode := range recorder.GetNodes() { + nodeHashHex := common.BytesToHex(recNode.Hash) if _, ok := trackedProofs[nodeHashHex]; !ok { - trackedProofs[nodeHashHex] = recNode.rawData + trackedProofs[nodeHashHex] = recNode.RawData } } } diff --git a/lib/trie/record/node.go b/lib/trie/record/node.go new file mode 100644 index 0000000000..eb3299e9bc --- /dev/null +++ b/lib/trie/record/node.go @@ -0,0 +1,7 @@ +package record + +// Node represents a record of a visited node +type Node struct { + RawData []byte + Hash []byte +} diff --git a/lib/trie/record/recorder.go b/lib/trie/record/recorder.go new file mode 100644 index 0000000000..130b434338 --- /dev/null +++ b/lib/trie/record/recorder.go @@ -0,0 +1,27 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package record + +// Recorder records the list of nodes found by Lookup.Find +type Recorder struct { + nodes []Node +} + +// NewRecorder creates a new recorder. +func NewRecorder() *Recorder { + return &Recorder{} +} + +// Record appends a node to the list of visited nodes. +func (r *Recorder) Record(hash, rawData []byte) { + r.nodes = append(r.nodes, Node{RawData: rawData, Hash: hash}) +} + +// GetNodes returns all the nodes recorded. +// Note it does not copy its slice of nodes. 
+// It's fine to not copy them since the recorder +// is not used again after a call to GetNodes() +func (r *Recorder) GetNodes() (nodes []Node) { + return r.nodes +} diff --git a/lib/trie/record/recorder_test.go b/lib/trie/record/recorder_test.go new file mode 100644 index 0000000000..638661b97a --- /dev/null +++ b/lib/trie/record/recorder_test.go @@ -0,0 +1,115 @@ +package record + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewRecorder(t *testing.T) { + t.Parallel() + + expected := &Recorder{} + + recorder := NewRecorder() + + assert.Equal(t, expected, recorder) +} + +func Test_Recorder_Record(t *testing.T) { + testCases := map[string]struct { + recorder *Recorder + hash []byte + rawData []byte + expectedRecorder *Recorder + }{ + "nil data": { + recorder: &Recorder{}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {}, + }, + }, + }, + "insert in empty recorder": { + recorder: &Recorder{}, + hash: []byte{1, 2}, + rawData: []byte{3, 4}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + }, + "insert in non-empty recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + }, + }, + hash: []byte{1, 2}, + rawData: []byte{3, 4}, + expectedRecorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.recorder.Record(testCase.hash, testCase.rawData) + + assert.Equal(t, testCase.expectedRecorder, testCase.recorder) + }) + } +} + +func Test_Recorder_GetNodes(t *testing.T) { + testCases := map[string]struct { + recorder *Recorder + nodes []Node + }{ + "no node": { + recorder: &Recorder{}, + }, + "get single node from recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + }, + }, + nodes: []Node{{Hash: []byte{1, 2}, RawData: []byte{3, 4}}}, + }, + "get node from multiple nodes in recorder": { + recorder: &Recorder{ + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, + }, + }, + nodes: []Node{ + {Hash: []byte{1, 2}, RawData: []byte{3, 4}}, + {Hash: []byte{5, 6}, RawData: []byte{7, 8}}, + {Hash: []byte{9, 6}, RawData: []byte{7, 8}}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodes := testCase.recorder.GetNodes() + + assert.Equal(t, testCase.nodes, nodes) + }) + } +} diff --git a/lib/trie/recorder.go b/lib/trie/recorder.go deleted file mode 100644 index 6db2a841d0..0000000000 --- a/lib/trie/recorder.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -// nodeRecord represets a record of a visited node -type nodeRecord struct { - rawData []byte - hash []byte -} - -// recorder keeps the list of nodes find by Lookup.Find -type recorder []nodeRecord - -// record insert a node inside the recorded list -func (r *recorder) record(h, rd []byte) { - *r = append(*r, nodeRecord{rawData: rd, hash: h}) -} - -// next returns the current item the cursor is on and increment the cursor by 1 -func (r *recorder) next() *nodeRecord { - if !r.isEmpty() { - n := (*r)[0] - *r = (*r)[1:] - return &n - } - - return nil -} - -// isEmpty returns bool if 
there is data inside the slice -func (r *recorder) isEmpty() bool { - return len(*r) <= 0 -} From e77adb20684a7c985b14693c238fc630b4918b63 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 1 Dec 2021 21:08:57 +0000 Subject: [PATCH 14/50] Remove ReadNextByte and use sync.Pool --- lib/trie/decode.go | 9 +++- lib/trie/decode/byte.go | 16 ------- lib/trie/decode/byte_test.go | 52 ----------------------- lib/trie/decode/key.go | 10 ++++- lib/trie/encodedecode_test/branch_test.go | 5 ++- lib/trie/pools/pools.go | 9 ++++ 6 files changed, 28 insertions(+), 73 deletions(-) delete mode 100644 lib/trie/decode/byte.go delete mode 100644 lib/trie/decode/byte_test.go diff --git a/lib/trie/decode.go b/lib/trie/decode.go index 452c60cdc7..46bf12d751 100644 --- a/lib/trie/decode.go +++ b/lib/trie/decode.go @@ -4,14 +4,15 @@ package trie import ( + "bytes" "errors" "fmt" "io" "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/decode" "github.com/ChainSafe/gossamer/lib/trie/leaf" "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/lib/trie/pools" ) var ( @@ -20,10 +21,14 @@ var ( ) func decodeNode(reader io.Reader) (n node.Node, err error) { - header, err := decode.ReadNextByte(reader) + buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) + defer pools.SingleByteBuffers.Put(buffer) + oneByteBuf := buffer.Bytes() + _, err = reader.Read(oneByteBuf) if err != nil { return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) } + header := oneByteBuf[0] nodeType := header >> 6 switch nodeType { diff --git a/lib/trie/decode/byte.go b/lib/trie/decode/byte.go deleted file mode 100644 index 4560fcbc90..0000000000 --- a/lib/trie/decode/byte.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package decode - -import "io" - -// ReadNextByte reads the next byte from the reader. 
-func ReadNextByte(reader io.Reader) (b byte, err error) { - buffer := make([]byte, 1) - _, err = reader.Read(buffer) - if err != nil { - return 0, err - } - return buffer[0], nil -} diff --git a/lib/trie/decode/byte_test.go b/lib/trie/decode/byte_test.go deleted file mode 100644 index 4a8115d287..0000000000 --- a/lib/trie/decode/byte_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package decode - -import ( - "bytes" - "io" - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_ReadNextByte(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - reader io.Reader - b byte - errWrapped error - errMessage string - }{ - "empty buffer": { - reader: bytes.NewBuffer(nil), - errWrapped: io.EOF, - errMessage: "EOF", - }, - "single byte buffer": { - reader: bytes.NewBuffer([]byte{1}), - b: 1, - }, - "two bytes buffer": { - reader: bytes.NewBuffer([]byte{1, 2}), - b: 1, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - b, err := ReadNextByte(testCase.reader) - - assert.ErrorIs(t, err, testCase.errWrapped) - if err != nil { - assert.EqualError(t, err, testCase.errMessage) - } - assert.Equal(t, testCase.b, b) - }) - } -} diff --git a/lib/trie/decode/key.go b/lib/trie/decode/key.go index 7c014bc944..4ab24adaff 100644 --- a/lib/trie/decode/key.go +++ b/lib/trie/decode/key.go @@ -4,9 +4,12 @@ package decode import ( + "bytes" "errors" "fmt" "io" + + "github.com/ChainSafe/gossamer/lib/trie/pools" ) const maxPartialKeySize = ^uint16(0) @@ -23,11 +26,16 @@ func Key(reader io.Reader, keyLength byte) (b []byte, err error) { if keyLength == 0x3f { // partial key longer than 63, read next bytes for rest of pk len + buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) + defer pools.SingleByteBuffers.Put(buffer) + oneByteBuf := buffer.Bytes() for { - nextKeyLen, err := ReadNextByte(reader) + _, err = reader.Read(oneByteBuf) if err != nil { return nil, fmt.Errorf("%w: %s", ErrReadKeyLength, err) } + nextKeyLen := oneByteBuf[0] + publicKeyLength += int(nextKeyLen) if nextKeyLen < 0xff { diff --git a/lib/trie/encodedecode_test/branch_test.go b/lib/trie/encodedecode_test/branch_test.go index 97e4e7f8c1..2e48656888 100644 --- a/lib/trie/encodedecode_test/branch_test.go +++ b/lib/trie/encodedecode_test/branch_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/decode" "github.com/ChainSafe/gossamer/lib/trie/leaf" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/stretchr/testify/assert" @@ -79,8 +78,10 @@ func Test_Branch_Encode_Decode(t *testing.T) { err := testCase.branchToEncode.Encode(buffer) require.NoError(t, err) - header, err := decode.ReadNextByte(buffer) + oneBuffer := make([]byte, 1) + _, err = buffer.Read(oneBuffer) require.NoError(t, err) + header := oneBuffer[0] resultBranch, err := branch.Decode(buffer, header) require.NoError(t, err) diff --git a/lib/trie/pools/pools.go b/lib/trie/pools/pools.go index 1bfe8f5a83..855232ef44 100644 --- a/lib/trie/pools/pools.go +++ b/lib/trie/pools/pools.go @@ -10,6 +10,15 @@ import ( "golang.org/x/crypto/blake2b" ) +// SingleByteBuffers is a sync pool of buffers of capacity 1. +var SingleByteBuffers = &sync.Pool{ + New: func() interface{} { + const bufferLength = 1 + b := make([]byte, bufferLength) + return bytes.NewBuffer(b) + }, +} + // DigestBuffers is a sync pool of buffers of capacity 32. 
var DigestBuffers = &sync.Pool{ New: func() interface{} { From 8697d2d066c11254c484c168985546eddbf5743d Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 2 Dec 2021 15:49:34 +0000 Subject: [PATCH 15/50] Rename `ExtraPartialKeyLength` to `KeyLength` --- lib/trie/branch/header.go | 2 +- lib/trie/encode/key.go | 16 ++++++++-------- lib/trie/encode/key_test.go | 20 ++++++++++---------- lib/trie/leaf/header.go | 2 +- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/lib/trie/branch/header.go b/lib/trie/branch/header.go index b990d00529..903b49e3fa 100644 --- a/lib/trie/branch/header.go +++ b/lib/trie/branch/header.go @@ -17,7 +17,7 @@ func (b *Branch) Header() (encoding []byte, err error) { var encodedPublicKeyLength []byte if len(b.Key) >= 63 { header = header | 0x3f - encodedPublicKeyLength, err = encode.ExtraPartialKeyLength(len(b.Key)) + encodedPublicKeyLength, err = encode.KeyLength(len(b.Key)) if err != nil { return nil, err } diff --git a/lib/trie/encode/key.go b/lib/trie/encode/key.go index 8a8e97ebd1..e69d637eed 100644 --- a/lib/trie/encode/key.go +++ b/lib/trie/encode/key.go @@ -12,22 +12,22 @@ const maxPartialKeySize = ^uint16(0) var ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16") -// ExtraPartialKeyLength encodes the public key length. -func ExtraPartialKeyLength(publicKeyLength int) (encoding []byte, err error) { - publicKeyLength -= 63 +// KeyLength encodes the public key length. +func KeyLength(keyLength int) (encoding []byte, err error) { + keyLength -= 63 - if publicKeyLength >= int(maxPartialKeySize) { + if keyLength >= int(maxPartialKeySize) { return nil, fmt.Errorf("%w: %d", - ErrPartialKeyTooBig, publicKeyLength) + ErrPartialKeyTooBig, keyLength) } for i := uint16(0); i < maxPartialKeySize; i++ { - if publicKeyLength < 255 { - encoding = append(encoding, byte(publicKeyLength)) + if keyLength < 255 { + encoding = append(encoding, byte(keyLength)) break } encoding = append(encoding, byte(255)) - publicKeyLength -= 255 + keyLength -= 255 } return encoding, nil diff --git a/lib/trie/encode/key_test.go b/lib/trie/encode/key_test.go index acdd5d0158..1984787f58 100644 --- a/lib/trie/encode/key_test.go +++ b/lib/trie/encode/key_test.go @@ -9,27 +9,27 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_ExtraPartialKeyLength(t *testing.T) { +func Test_KeyLength(t *testing.T) { t.Parallel() testCases := map[string]struct { - publicKeyLength int - encoding []byte - err error + partialKeyLength int + encoding []byte + err error }{ "length equal to maximum": { - publicKeyLength: int(maxPartialKeySize) + 63, - err: ErrPartialKeyTooBig, + partialKeyLength: int(maxPartialKeySize) + 63, + err: ErrPartialKeyTooBig, }, "zero length": { encoding: []byte{0xc1}, }, "one length": { - publicKeyLength: 1, - encoding: []byte{0xc2}, + partialKeyLength: 1, + encoding: []byte{0xc2}, }, "length at maximum allowed": { - publicKeyLength: int(maxPartialKeySize) + 62, + partialKeyLength: int(maxPartialKeySize) + 62, encoding: []byte{ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, @@ -65,7 +65,7 @@ func Test_ExtraPartialKeyLength(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - encoding, err := ExtraPartialKeyLength(testCase.publicKeyLength) + encoding, err := KeyLength(testCase.partialKeyLength) assert.ErrorIs(t, err, testCase.err) assert.Equal(t, testCase.encoding, encoding) diff --git a/lib/trie/leaf/header.go b/lib/trie/leaf/header.go index 
264d38ab05..2c0e009fa6 100644 --- a/lib/trie/leaf/header.go +++ b/lib/trie/leaf/header.go @@ -12,7 +12,7 @@ func (l *Leaf) Header() (encoding []byte, err error) { if len(l.Key) >= 63 { header = header | 0x3f - encodedPublicKeyLength, err = encode.ExtraPartialKeyLength(len(l.Key)) + encodedPublicKeyLength, err = encode.KeyLength(len(l.Key)) if err != nil { return nil, err } From b03ecff762b0ef620a76f60ba9cdc13d1f41a736 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 2 Dec 2021 16:01:11 +0000 Subject: [PATCH 16/50] Remove unneded `NibblesToKey` --- lib/trie/encode/key.go | 26 -------------------- lib/trie/encode/key_test.go | 48 ------------------------------------- 2 files changed, 74 deletions(-) diff --git a/lib/trie/encode/key.go b/lib/trie/encode/key.go index e69d637eed..fb4fbc2e40 100644 --- a/lib/trie/encode/key.go +++ b/lib/trie/encode/key.go @@ -33,32 +33,6 @@ func KeyLength(keyLength int) (encoding []byte, err error) { return encoding, nil } -// NibblesToKey converts a slice of nibbles with length k into a -// Big Endian byte slice. -// It assumes nibbles are already in Little Endian and does not rearrange nibbles. -// If the length of the input is odd, the result is -// [ in[1] in[0] | ... | 0000 in[k-1] ] -// Otherwise, the result is -// [ in[1] in[0] | ... | in[k-1] in[k-2] ] -func NibblesToKey(nibbles []byte) (key []byte) { - if len(nibbles)%2 == 0 { - key = make([]byte, len(nibbles)/2) - for i := 0; i < len(nibbles); i += 2 { - key[i/2] = (nibbles[i] & 0xf) | (nibbles[i+1] << 4 & 0xf0) - } - } else { - key = make([]byte, len(nibbles)/2+1) - for i := 0; i < len(nibbles); i += 2 { - key[i/2] = nibbles[i] & 0xf - if i < len(nibbles)-1 { - key[i/2] |= (nibbles[i+1] << 4 & 0xf0) - } - } - } - - return key -} - // NibblesToKeyLE converts a slice of nibbles with length k into a // Little Endian byte slice. // It assumes nibbles are already in Little Endian and does not rearrange nibbles. 
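The two helpers kept above are used as inverses throughout the trie code: Put, Get and the prefix-clearing paths expand keys with decode.KeyLEToNibbles, while addAllKeys packs node partial keys back with encode.NibblesToKeyLE. A minimal round-trip sketch, assuming the two functions remain exact inverses for the even-length nibble slices that KeyLEToNibbles produces; the key bytes and the main wrapper here are illustrative only, not part of the patch:

package main

import (
	"bytes"
	"fmt"

	"github.com/ChainSafe/gossamer/lib/trie/decode"
	"github.com/ChainSafe/gossamer/lib/trie/encode"
)

func main() {
	// An arbitrary little-endian key as it would be passed to trie.Put or trie.Get.
	key := []byte{0x01, 0x23, 0x45}

	// Expand each byte into two nibbles, the form the trie uses internally
	// for partial keys and prefixes.
	nibbles := decode.KeyLEToNibbles(key)

	// Pack the nibbles back into a little-endian key, as addAllKeys does when
	// reconstructing full keys from node partial keys.
	roundTrip := encode.NibblesToKeyLE(nibbles)

	fmt.Println(bytes.Equal(key, roundTrip)) // expected: true
}
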
diff --git a/lib/trie/encode/key_test.go b/lib/trie/encode/key_test.go index 1984787f58..a4541083aa 100644 --- a/lib/trie/encode/key_test.go +++ b/lib/trie/encode/key_test.go @@ -73,54 +73,6 @@ func Test_KeyLength(t *testing.T) { } } -func Test_NibblesToKey(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - nibbles []byte - key []byte - }{ - "nil nibbles": { - key: []byte{}, - }, - "empty nibbles": { - nibbles: []byte{}, - key: []byte{}, - }, - "0xF 0xF": { - nibbles: []byte{0xF, 0xF}, - key: []byte{0xFF}, - }, - "0x3 0xa 0x0 0x5": { - nibbles: []byte{0x3, 0xa, 0x0, 0x5}, - key: []byte{0xa3, 0x50}, - }, - "0xa 0xa 0xf 0xf 0x0 0x1": { - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, - key: []byte{0xaa, 0xff, 0x10}, - }, - "0xa 0xa 0xf 0xf 0x0 0x1 0xc 0x2": { - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, - key: []byte{0xaa, 0xff, 0x10, 0x2c}, - }, - "0xa 0xa 0xf 0xf 0x0 0x1 0xc": { - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, - key: []byte{0xaa, 0xff, 0x10, 0x0c}, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - key := NibblesToKey(testCase.nibbles) - - assert.Equal(t, testCase.key, key) - }) - } -} - func Test_NibblesToKeyLE(t *testing.T) { t.Parallel() From 207a5c3a39b37a403b4895decd24430e8c99a02e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 6 Dec 2021 16:47:16 +0000 Subject: [PATCH 17/50] Rename `Header()` to `encodeHeader()` --- lib/trie/branch/encode.go | 2 +- lib/trie/branch/header.go | 4 ++-- lib/trie/branch/header_test.go | 4 ++-- lib/trie/leaf/encode.go | 2 +- lib/trie/leaf/header.go | 4 ++-- lib/trie/leaf/header_test.go | 4 ++-- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lib/trie/branch/encode.go b/lib/trie/branch/encode.go index 04d7bc9cf9..9e9bfb0bf9 100644 --- a/lib/trie/branch/encode.go +++ b/lib/trie/branch/encode.go @@ -86,7 +86,7 @@ func (b *Branch) Encode(buffer encode.Buffer) (err error) { return nil } - encodedHeader, err := b.Header() + encodedHeader, err := b.encodeHeader() if err != nil { return fmt.Errorf("cannot encode header: %w", err) } diff --git a/lib/trie/branch/header.go b/lib/trie/branch/header.go index 903b49e3fa..343eeddae5 100644 --- a/lib/trie/branch/header.go +++ b/lib/trie/branch/header.go @@ -5,8 +5,8 @@ package branch import "github.com/ChainSafe/gossamer/lib/trie/encode" -// Header creates the encoded header for the branch. -func (b *Branch) Header() (encoding []byte, err error) { +// encodeHeader creates the encoded header for the branch. 
+func (b *Branch) encodeHeader() (encoding []byte, err error) { var header byte if b.Value == nil { header = 2 << 6 diff --git a/lib/trie/branch/header_test.go b/lib/trie/branch/header_test.go index ad251b468c..aac1edead4 100644 --- a/lib/trie/branch/header_test.go +++ b/lib/trie/branch/header_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_Branch_Header(t *testing.T) { +func Test_Branch_encodeHeader(t *testing.T) { testCases := map[string]struct { branch *Branch encoding []byte @@ -65,7 +65,7 @@ func Test_Branch_Header(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - encoding, err := testCase.branch.Header() + encoding, err := testCase.branch.encodeHeader() if testCase.wrappedErr != nil { assert.ErrorIs(t, err, testCase.wrappedErr) diff --git a/lib/trie/leaf/encode.go b/lib/trie/leaf/encode.go index ca1dd62cc4..3761908d36 100644 --- a/lib/trie/leaf/encode.go +++ b/lib/trie/leaf/encode.go @@ -97,7 +97,7 @@ func (l *Leaf) Encode(buffer encode.Buffer) (err error) { } l.encodingMu.RUnlock() - encodedHeader, err := l.Header() + encodedHeader, err := l.encodeHeader() if err != nil { return fmt.Errorf("cannot encode header: %w", err) } diff --git a/lib/trie/leaf/header.go b/lib/trie/leaf/header.go index 2c0e009fa6..6046bad173 100644 --- a/lib/trie/leaf/header.go +++ b/lib/trie/leaf/header.go @@ -5,8 +5,8 @@ package leaf import "github.com/ChainSafe/gossamer/lib/trie/encode" -// Header creates the encoded header for the leaf. -func (l *Leaf) Header() (encoding []byte, err error) { +// encodeHeader creates the encoded header for the leaf. +func (l *Leaf) encodeHeader() (encoding []byte, err error) { var header byte = 1 << 6 var encodedPublicKeyLength []byte diff --git a/lib/trie/leaf/header_test.go b/lib/trie/leaf/header_test.go index a1825034c1..5348dd2e3d 100644 --- a/lib/trie/leaf/header_test.go +++ b/lib/trie/leaf/header_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_Leaf_Header(t *testing.T) { +func Test_Leaf_encodeHeader(t *testing.T) { testCases := map[string]struct { leaf *Leaf encoding []byte @@ -59,7 +59,7 @@ func Test_Leaf_Header(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - encoding, err := testCase.leaf.Header() + encoding, err := testCase.leaf.encodeHeader() if testCase.wrappedErr != nil { assert.ErrorIs(t, err, testCase.wrappedErr) From 37425a7120e106a2b63f8fbe56937a342ccaf0e3 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 08:38:30 +0000 Subject: [PATCH 18/50] encode headers directly to buffer --- lib/trie/branch/encode.go | 7 +- lib/trie/branch/encode_test.go | 15 +-- lib/trie/branch/header.go | 27 +++-- lib/trie/branch/header_test.go | 86 ++++++++++++---- lib/trie/encode/key.go | 18 +++- lib/trie/encode/key_test.go | 153 ++++++++++++++++++++-------- lib/trie/encode/writer_mock_test.go | 49 +++++++++ lib/trie/leaf/encode.go | 7 +- lib/trie/leaf/encode_test.go | 14 +-- lib/trie/leaf/header.go | 34 ++++--- lib/trie/leaf/header_test.go | 84 ++++++++++++--- 11 files changed, 359 insertions(+), 135 deletions(-) create mode 100644 lib/trie/encode/writer_mock_test.go diff --git a/lib/trie/branch/encode.go b/lib/trie/branch/encode.go index 9e9bfb0bf9..37d49f74cd 100644 --- a/lib/trie/branch/encode.go +++ b/lib/trie/branch/encode.go @@ -86,16 +86,11 @@ func (b *Branch) Encode(buffer encode.Buffer) (err error) { return nil } - encodedHeader, err := b.encodeHeader() + err = b.encodeHeader(buffer) if err != nil { return fmt.Errorf("cannot encode header: %w", err) } - 
_, err = buffer.Write(encodedHeader) - if err != nil { - return fmt.Errorf("cannot write encoded header to buffer: %w", err) - } - keyLE := encode.NibblesToKeyLE(b.Key) _, err = buffer.Write(keyLE) if err != nil { diff --git a/lib/trie/branch/encode_test.go b/lib/trie/branch/encode_test.go index 7153329bc5..3de07d6f5c 100644 --- a/lib/trie/branch/encode_test.go +++ b/lib/trie/branch/encode_test.go @@ -61,22 +61,13 @@ func Test_Branch_Encode(t *testing.T) { branch: &Branch{ Key: make([]byte, 63+(1<<16)), }, - wrappedErr: encode.ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", - }, - "buffer write error for encoded header": { - branch: &Branch{ - Key: []byte{1, 2, 3}, - Value: []byte{100}, - }, writes: []writeCall{ { // header - written: []byte{195}, - err: errTest, + written: []byte{191}, }, }, - wrappedErr: errTest, - errMessage: "cannot write encoded header to buffer: test error", + wrappedErr: encode.ErrPartialKeyTooBig, + errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", }, "buffer write error for encoded key": { branch: &Branch{ diff --git a/lib/trie/branch/header.go b/lib/trie/branch/header.go index 343eeddae5..bbc7a683b3 100644 --- a/lib/trie/branch/header.go +++ b/lib/trie/branch/header.go @@ -3,10 +3,14 @@ package branch -import "github.com/ChainSafe/gossamer/lib/trie/encode" +import ( + "io" + + "github.com/ChainSafe/gossamer/lib/trie/encode" +) // encodeHeader creates the encoded header for the branch. -func (b *Branch) encodeHeader() (encoding []byte, err error) { +func (b *Branch) encodeHeader(writer io.Writer) (err error) { var header byte if b.Value == nil { header = 2 << 6 @@ -14,19 +18,24 @@ func (b *Branch) encodeHeader() (encoding []byte, err error) { header = 3 << 6 } - var encodedPublicKeyLength []byte if len(b.Key) >= 63 { header = header | 0x3f - encodedPublicKeyLength, err = encode.KeyLength(len(b.Key)) + _, err = writer.Write([]byte{header}) + if err != nil { + return err + } + + err = encode.KeyLength(len(b.Key), writer) if err != nil { - return nil, err + return err } } else { header = header | byte(len(b.Key)) + _, err = writer.Write([]byte{header}) + if err != nil { + return err + } } - encoding = make([]byte, 0, len(encodedPublicKeyLength)+1) - encoding = append(encoding, header) - encoding = append(encoding, encodedPublicKeyLength...) 
- return encoding, nil + return nil } diff --git a/lib/trie/branch/header_test.go b/lib/trie/branch/header_test.go index aac1edead4..2209518549 100644 --- a/lib/trie/branch/header_test.go +++ b/lib/trie/branch/header_test.go @@ -7,74 +7,126 @@ import ( "testing" "github.com/ChainSafe/gossamer/lib/trie/encode" + gomock "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) func Test_Branch_encodeHeader(t *testing.T) { testCases := map[string]struct { branch *Branch - encoding []byte - wrappedErr error + writes []writeCall + errWrapped error errMessage string }{ "no key": { - branch: &Branch{}, - encoding: []byte{0x80}, + branch: &Branch{}, + writes: []writeCall{ + {written: []byte{0x80}}, + }, }, "with value": { branch: &Branch{ Value: []byte{}, }, - encoding: []byte{0xc0}, + writes: []writeCall{ + {written: []byte{0xc0}}, + }, }, "key of length 30": { branch: &Branch{ Key: make([]byte, 30), }, - encoding: []byte{0x9e}, + writes: []writeCall{ + {written: []byte{0x9e}}, + }, }, "key of length 62": { branch: &Branch{ Key: make([]byte, 62), }, - encoding: []byte{0xbe}, + writes: []writeCall{ + {written: []byte{0xbe}}, + }, }, "key of length 63": { branch: &Branch{ Key: make([]byte, 63), }, - encoding: []byte{0xbf, 0x0}, + writes: []writeCall{ + {written: []byte{0xbf}}, + {written: []byte{0x0}}, + }, }, "key of length 64": { branch: &Branch{ Key: make([]byte, 64), }, - encoding: []byte{0xbf, 0x1}, + writes: []writeCall{ + {written: []byte{0xbf}}, + {written: []byte{0x1}}, + }, }, "key too big": { branch: &Branch{ Key: make([]byte, 65535+63), }, - wrappedErr: encode.ErrPartialKeyTooBig, + writes: []writeCall{ + {written: []byte{0xbf}}, + }, + errWrapped: encode.ErrPartialKeyTooBig, errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", }, + "small key length write error": { + branch: &Branch{}, + writes: []writeCall{ + { + written: []byte{0x80}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: "test error", + }, + "long key length write error": { + branch: &Branch{ + Key: make([]byte, 64), + }, + writes: []writeCall{ + { + written: []byte{0xbf}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: "test error", + }, } for name, testCase := range testCases { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) - encoding, err := testCase.branch.encodeHeader() + writer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := writer.EXPECT(). + Write(write.written). + Return(write.n, write.err) - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - assert.NoError(t, err) + if previousCall != nil { + call.After(previousCall) + } + previousCall = call } - assert.Equal(t, testCase.encoding, encoding) + err := testCase.branch.encodeHeader(writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } }) } } diff --git a/lib/trie/encode/key.go b/lib/trie/encode/key.go index fb4fbc2e40..3747df17b8 100644 --- a/lib/trie/encode/key.go +++ b/lib/trie/encode/key.go @@ -6,6 +6,7 @@ package encode import ( "errors" "fmt" + "io" ) const maxPartialKeySize = ^uint16(0) @@ -13,24 +14,31 @@ const maxPartialKeySize = ^uint16(0) var ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16") // KeyLength encodes the public key length. 
-func KeyLength(keyLength int) (encoding []byte, err error) { +func KeyLength(keyLength int, writer io.Writer) (err error) { keyLength -= 63 if keyLength >= int(maxPartialKeySize) { - return nil, fmt.Errorf("%w: %d", + return fmt.Errorf("%w: %d", ErrPartialKeyTooBig, keyLength) } for i := uint16(0); i < maxPartialKeySize; i++ { if keyLength < 255 { - encoding = append(encoding, byte(keyLength)) + _, err = writer.Write([]byte{byte(keyLength)}) + if err != nil { + return err + } break } - encoding = append(encoding, byte(255)) + _, err = writer.Write([]byte{255}) + if err != nil { + return err + } + keyLength -= 255 } - return encoding, nil + return nil } // NibblesToKeyLE converts a slice of nibbles with length k into a diff --git a/lib/trie/encode/key_test.go b/lib/trie/encode/key_test.go index a4541083aa..058be5643e 100644 --- a/lib/trie/encode/key_test.go +++ b/lib/trie/encode/key_test.go @@ -4,59 +4,90 @@ package encode import ( + "bytes" + "errors" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +type writeCall struct { + written []byte + n int + err error +} + +var errTest = errors.New("test error") + +//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer + func Test_KeyLength(t *testing.T) { t.Parallel() testCases := map[string]struct { - partialKeyLength int - encoding []byte - err error + keyLength int + writes []writeCall + errWrapped error + errMessage string }{ "length equal to maximum": { - partialKeyLength: int(maxPartialKeySize) + 63, - err: ErrPartialKeyTooBig, + keyLength: int(maxPartialKeySize) + 63, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be " + + "larger than or equal to 2^16: 65535", }, "zero length": { - encoding: []byte{0xc1}, + writes: []writeCall{ + { + written: []byte{0xc1}, + }, + }, }, "one length": { - partialKeyLength: 1, - encoding: []byte{0xc2}, + keyLength: 1, + writes: []writeCall{ + { + written: []byte{0xc2}, + }, + }, }, - "length at maximum allowed": { - partialKeyLength: int(maxPartialKeySize) + 62, - encoding: []byte{ - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}, + "error at single byte write": { + keyLength: 1, + writes: []writeCall{ + { + written: []byte{0xc2}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + "error at first byte write": { + keyLength: 255 + 100 + 63, + writes: []writeCall{ + { + written: []byte{255}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + "error at last byte write": { + keyLength: 255 + 100 + 63, + writes: []writeCall{ + { + written: []byte{255}, + }, + { + written: []byte{100}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), }, } @@ -64,13 +95,55 @@ func Test_KeyLength(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + + writer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := writer.EXPECT(). + Write(write.written). + Return(write.n, write.err) - encoding, err := KeyLength(testCase.partialKeyLength) + if write.err != nil { + break + } else if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } - assert.ErrorIs(t, err, testCase.err) - assert.Equal(t, testCase.encoding, encoding) + err := KeyLength(testCase.keyLength, writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } }) } + + t.Run("length at maximum", func(t *testing.T) { + t.Parallel() + + // Note: this test case cannot run with the + // mock writer since it's too slow, so we use + // an actual buffer. + + const keyLength = int(maxPartialKeySize) + 62 + const expectedEncodingLength = 257 + expectedBytes := make([]byte, expectedEncodingLength) + for i := 0; i < len(expectedBytes)-1; i++ { + expectedBytes[i] = 255 + } + expectedBytes[len(expectedBytes)-1] = 254 + + buffer := bytes.NewBuffer(nil) + buffer.Grow(expectedEncodingLength) + + err := KeyLength(keyLength, buffer) + + require.NoError(t, err) + assert.Equal(t, expectedBytes, buffer.Bytes()) + }) } func Test_NibblesToKeyLE(t *testing.T) { diff --git a/lib/trie/encode/writer_mock_test.go b/lib/trie/encode/writer_mock_test.go new file mode 100644 index 0000000000..171dc73964 --- /dev/null +++ b/lib/trie/encode/writer_mock_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: io (interfaces: Writer) + +// Package encode is a generated GoMock package. +package encode + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockWriter is a mock of Writer interface. +type MockWriter struct { + ctrl *gomock.Controller + recorder *MockWriterMockRecorder +} + +// MockWriterMockRecorder is the mock recorder for MockWriter. +type MockWriterMockRecorder struct { + mock *MockWriter +} + +// NewMockWriter creates a new mock instance. +func NewMockWriter(ctrl *gomock.Controller) *MockWriter { + mock := &MockWriter{ctrl: ctrl} + mock.recorder = &MockWriterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWriter) EXPECT() *MockWriterMockRecorder { + return m.recorder +} + +// Write mocks base method. 
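KeyLength now streams the extended length directly: after subtracting 63, the remainder is written as a run of 0xFF bytes terminated by one byte below 255, which is why the maximum-length case above yields 256 bytes of 0xFF followed by 0xFE. A self-contained sketch of the same scheme, with writeExtraKeyLength and errPartialKeyTooBig as illustrative stand-ins rather than the package's own names:

package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
)

var errPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16")

// writeExtraKeyLength sketches the variable-length tail written after a
// saturated header: 63 is subtracted from the key length, then the
// remainder is emitted as 0xFF bytes until a final byte below 255.
func writeExtraKeyLength(keyLength int, writer io.Writer) error {
	keyLength -= 63
	if keyLength >= 1<<16-1 {
		return fmt.Errorf("%w: %d", errPartialKeyTooBig, keyLength)
	}

	for keyLength >= 255 {
		if _, err := writer.Write([]byte{255}); err != nil {
			return err
		}
		keyLength -= 255
	}
	_, err := writer.Write([]byte{byte(keyLength)})
	return err
}

func main() {
	buffer := bytes.NewBuffer(nil)
	if err := writeExtraKeyLength(363, buffer); err != nil {
		panic(err)
	}
	// 363 nibbles: 363-63=300, so one 0xFF byte then 45.
	fmt.Println(buffer.Bytes()) // [255 45]
}

The decoding side reverses this by summing the length bytes until one below 255 appears and adding 63 back.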
+func (m *MockWriter) Write(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write. +func (mr *MockWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), arg0) +} diff --git a/lib/trie/leaf/encode.go b/lib/trie/leaf/encode.go index 3761908d36..477d6dd78d 100644 --- a/lib/trie/leaf/encode.go +++ b/lib/trie/leaf/encode.go @@ -97,16 +97,11 @@ func (l *Leaf) Encode(buffer encode.Buffer) (err error) { } l.encodingMu.RUnlock() - encodedHeader, err := l.encodeHeader() + err = l.encodeHeader(buffer) if err != nil { return fmt.Errorf("cannot encode header: %w", err) } - _, err = buffer.Write(encodedHeader) - if err != nil { - return fmt.Errorf("cannot write encoded header to buffer: %w", err) - } - keyLE := encode.NibblesToKeyLE(l.Key) _, err = buffer.Write(keyLE) if err != nil { diff --git a/lib/trie/leaf/encode_test.go b/lib/trie/leaf/encode_test.go index 513b61eb90..51646719b2 100644 --- a/lib/trie/leaf/encode_test.go +++ b/lib/trie/leaf/encode_test.go @@ -65,21 +65,13 @@ func Test_Leaf_Encode(t *testing.T) { leaf: &Leaf{ Key: make([]byte, 63+(1<<16)), }, - wrappedErr: encode.ErrPartialKeyTooBig, - errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", - }, - "buffer write error for encoded header": { - leaf: &Leaf{ - Key: []byte{1, 2, 3}, - }, writes: []writeCall{ { - written: []byte{67}, - err: errTest, + written: []byte{127}, }, }, - wrappedErr: errTest, - errMessage: "cannot write encoded header to buffer: test error", + wrappedErr: encode.ErrPartialKeyTooBig, + errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", }, "buffer write error for encoded key": { leaf: &Leaf{ diff --git a/lib/trie/leaf/header.go b/lib/trie/leaf/header.go index 6046bad173..930658f65f 100644 --- a/lib/trie/leaf/header.go +++ b/lib/trie/leaf/header.go @@ -3,24 +3,32 @@ package leaf -import "github.com/ChainSafe/gossamer/lib/trie/encode" +import ( + "io" + + "github.com/ChainSafe/gossamer/lib/trie/encode" +) // encodeHeader creates the encoded header for the leaf. -func (l *Leaf) encodeHeader() (encoding []byte, err error) { +func (l *Leaf) encodeHeader(writer io.Writer) (err error) { var header byte = 1 << 6 - var encodedPublicKeyLength []byte - - if len(l.Key) >= 63 { - header = header | 0x3f - encodedPublicKeyLength, err = encode.KeyLength(len(l.Key)) - if err != nil { - return nil, err - } - } else { + + if len(l.Key) < 63 { header = header | byte(len(l.Key)) + _, err = writer.Write([]byte{header}) + return err } - encoding = append([]byte{header}, encodedPublicKeyLength...) 
+ header = header | 0x3f + _, err = writer.Write([]byte{header}) + if err != nil { + return err + } + + err = encode.KeyLength(len(l.Key), writer) + if err != nil { + return err + } - return encoding, nil + return nil } diff --git a/lib/trie/leaf/header_test.go b/lib/trie/leaf/header_test.go index 5348dd2e3d..08f03d0ed9 100644 --- a/lib/trie/leaf/header_test.go +++ b/lib/trie/leaf/header_test.go @@ -7,49 +7,91 @@ import ( "testing" "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) func Test_Leaf_encodeHeader(t *testing.T) { testCases := map[string]struct { leaf *Leaf - encoding []byte - wrappedErr error + writes []writeCall + errWrapped error errMessage string }{ "no key": { - leaf: &Leaf{}, - encoding: []byte{0x40}, + leaf: &Leaf{}, + writes: []writeCall{ + {written: []byte{0x40}}, + }, }, "key of length 30": { leaf: &Leaf{ Key: make([]byte, 30), }, - encoding: []byte{0x5e}, + writes: []writeCall{ + {written: []byte{0x5e}}, + }, + }, + "short key write error": { + leaf: &Leaf{ + Key: make([]byte, 30), + }, + writes: []writeCall{ + { + written: []byte{0x5e}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), }, "key of length 62": { leaf: &Leaf{ Key: make([]byte, 62), }, - encoding: []byte{0x7e}, + writes: []writeCall{ + {written: []byte{0x7e}}, + }, }, "key of length 63": { leaf: &Leaf{ Key: make([]byte, 63), }, - encoding: []byte{0x7f, 0x0}, + writes: []writeCall{ + {written: []byte{0x7f}}, + {written: []byte{0x0}}, + }, }, "key of length 64": { leaf: &Leaf{ Key: make([]byte, 64), }, - encoding: []byte{0x7f, 0x1}, + writes: []writeCall{ + {written: []byte{0x7f}}, + {written: []byte{0x1}}, + }, + }, + "long key first byte write error": { + leaf: &Leaf{ + Key: make([]byte, 63), + }, + writes: []writeCall{ + { + written: []byte{0x7f}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), }, "key too big": { leaf: &Leaf{ Key: make([]byte, 65535+63), }, - wrappedErr: encode.ErrPartialKeyTooBig, + writes: []writeCall{ + {written: []byte{0x7f}}, + }, + errWrapped: encode.ErrPartialKeyTooBig, errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", }, } @@ -58,17 +100,27 @@ func Test_Leaf_encodeHeader(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) - encoding, err := testCase.leaf.encodeHeader() + writer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := writer.EXPECT(). + Write(write.written). 
+ Return(write.n, write.err) - if testCase.wrappedErr != nil { - assert.ErrorIs(t, err, testCase.wrappedErr) - assert.EqualError(t, err, testCase.errMessage) - } else { - assert.NoError(t, err) + if previousCall != nil { + call.After(previousCall) + } + previousCall = call } - assert.Equal(t, testCase.encoding, encoding) + err := testCase.leaf.encodeHeader(writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } }) } } From 301188d9bf47f6e0bc7ae655cecfb66609e89f2f Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 09:24:49 +0000 Subject: [PATCH 19/50] Merge packages in `lib/trie/node` - `lib/trie/leaf` - `lib/trie/branch` - `lib/trie/encode` - `lib/trie/decode` --- lib/trie/branch/buffer_mock_test.go | 77 ---------- lib/trie/branch/key.go | 11 -- lib/trie/branch/writer_mock_test.go | 49 ------ lib/trie/{encode/key.go => codec/nibbles.go} | 59 +++----- lib/trie/codec/nibbles_test.go | 142 ++++++++++++++++++ lib/trie/database.go | 24 ++- lib/trie/decode.go | 6 +- lib/trie/decode/key_test.go | 138 ----------------- lib/trie/decode_test.go | 11 +- lib/trie/encode/writer_mock_test.go | 49 ------ lib/trie/encodedecode_test/nibbles_test.go | 50 ------ lib/trie/leaf/copy.go | 29 ---- lib/trie/leaf/decode.go | 50 ------ lib/trie/leaf/decode_test.go | 109 -------------- lib/trie/leaf/dirty.go | 14 -- lib/trie/leaf/generation.go | 14 -- lib/trie/leaf/header.go | 34 ----- lib/trie/leaf/header_test.go | 126 ---------------- lib/trie/leaf/key.go | 11 -- lib/trie/lookup.go | 3 +- lib/trie/{branch => node}/branch.go | 7 +- .../encode.go => node/branch_encode.go} | 20 ++- .../branch_encode_test.go} | 92 +++++------- lib/trie/{encode => node}/buffer.go | 2 +- lib/trie/{leaf => node}/buffer_mock_test.go | 4 +- lib/trie/{branch => node}/children.go | 2 +- lib/trie/{branch => node}/children_test.go | 40 +++-- lib/trie/{branch => node}/copy.go | 29 +++- lib/trie/{branch => node}/decode.go | 46 +++++- lib/trie/{branch => node}/decode_test.go | 95 ++++++++++-- lib/trie/{branch => node}/dirty.go | 12 +- .../encode_decode_test.go} | 35 ++--- .../{encode/doc.go => node/encode_doc.go} | 5 +- lib/trie/node/encode_test.go | 11 ++ lib/trie/{branch => node}/generation.go | 12 +- lib/trie/{branch => node}/hash.go | 2 +- lib/trie/{branch => node}/header.go | 30 +++- lib/trie/{branch => node}/header_test.go | 121 ++++++++++++++- lib/trie/{decode => node}/key.go | 70 ++++++--- lib/trie/{encode => node}/key_test.go | 103 +++++++------ lib/trie/{leaf => node}/leaf.go | 5 +- .../{leaf/encode.go => node/leaf_encode.go} | 8 +- .../leaf_encode_test.go} | 14 +- lib/trie/node/{interface.go => node.go} | 6 +- lib/trie/{leaf => node}/writer_mock_test.go | 4 +- lib/trie/print.go | 6 +- lib/trie/proof.go | 4 +- lib/trie/trie.go | 115 +++++++------- lib/trie/trie_test.go | 22 +-- 49 files changed, 786 insertions(+), 1142 deletions(-) delete mode 100644 lib/trie/branch/buffer_mock_test.go delete mode 100644 lib/trie/branch/key.go delete mode 100644 lib/trie/branch/writer_mock_test.go rename lib/trie/{encode/key.go => codec/nibbles.go} (55%) create mode 100644 lib/trie/codec/nibbles_test.go delete mode 100644 lib/trie/decode/key_test.go delete mode 100644 lib/trie/encode/writer_mock_test.go delete mode 100644 lib/trie/encodedecode_test/nibbles_test.go delete mode 100644 lib/trie/leaf/copy.go delete mode 100644 lib/trie/leaf/decode.go delete mode 100644 lib/trie/leaf/decode_test.go delete mode 100644 lib/trie/leaf/dirty.go delete mode 
100644 lib/trie/leaf/generation.go delete mode 100644 lib/trie/leaf/header.go delete mode 100644 lib/trie/leaf/header_test.go delete mode 100644 lib/trie/leaf/key.go rename lib/trie/{branch => node}/branch.go (85%) rename lib/trie/{branch/encode.go => node/branch_encode.go} (90%) rename lib/trie/{branch/encode_test.go => node/branch_encode_test.go} (86%) rename lib/trie/{encode => node}/buffer.go (94%) rename lib/trie/{leaf => node}/buffer_mock_test.go (97%) rename lib/trie/{branch => node}/children.go (97%) rename lib/trie/{branch => node}/children_test.go (75%) rename lib/trie/{branch => node}/copy.go (54%) rename lib/trie/{branch => node}/decode.go (63%) rename lib/trie/{branch => node}/decode_test.go (63%) rename lib/trie/{branch => node}/dirty.go (57%) rename lib/trie/{encodedecode_test/branch_test.go => node/encode_decode_test.go} (66%) rename lib/trie/{encode/doc.go => node/encode_doc.go} (92%) create mode 100644 lib/trie/node/encode_test.go rename lib/trie/{branch => node}/generation.go (56%) rename lib/trie/{branch => node}/hash.go (99%) rename lib/trie/{branch => node}/header.go (54%) rename lib/trie/{branch => node}/header_test.go (51%) rename lib/trie/{decode => node}/key.go (55%) rename lib/trie/{encode => node}/key_test.go (59%) rename lib/trie/{leaf => node}/leaf.go (87%) rename lib/trie/{leaf/encode.go => node/leaf_encode.go} (96%) rename lib/trie/{leaf/encode_test.go => node/leaf_encode_test.go} (96%) rename lib/trie/node/{interface.go => node.go} (77%) rename lib/trie/{leaf => node}/writer_mock_test.go (95%) diff --git a/lib/trie/branch/buffer_mock_test.go b/lib/trie/branch/buffer_mock_test.go deleted file mode 100644 index 3e864b1cfc..0000000000 --- a/lib/trie/branch/buffer_mock_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ChainSafe/gossamer/lib/trie/encode (interfaces: Buffer) - -// Package branch is a generated GoMock package. -package branch - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockBuffer is a mock of Buffer interface. -type MockBuffer struct { - ctrl *gomock.Controller - recorder *MockBufferMockRecorder -} - -// MockBufferMockRecorder is the mock recorder for MockBuffer. -type MockBufferMockRecorder struct { - mock *MockBuffer -} - -// NewMockBuffer creates a new mock instance. -func NewMockBuffer(ctrl *gomock.Controller) *MockBuffer { - mock := &MockBuffer{ctrl: ctrl} - mock.recorder = &MockBufferMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockBuffer) EXPECT() *MockBufferMockRecorder { - return m.recorder -} - -// Bytes mocks base method. -func (m *MockBuffer) Bytes() []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Bytes") - ret0, _ := ret[0].([]byte) - return ret0 -} - -// Bytes indicates an expected call of Bytes. -func (mr *MockBufferMockRecorder) Bytes() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bytes", reflect.TypeOf((*MockBuffer)(nil).Bytes)) -} - -// Len mocks base method. -func (m *MockBuffer) Len() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Len") - ret0, _ := ret[0].(int) - return ret0 -} - -// Len indicates an expected call of Len. -func (mr *MockBufferMockRecorder) Len() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockBuffer)(nil).Len)) -} - -// Write mocks base method. 
-func (m *MockBuffer) Write(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockBufferMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockBuffer)(nil).Write), arg0) -} diff --git a/lib/trie/branch/key.go b/lib/trie/branch/key.go deleted file mode 100644 index aa88e8a0c3..0000000000 --- a/lib/trie/branch/key.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package branch - -// SetKey sets the key to the branch. -// Note it does not copy it so modifying the passed key -// will modify the key stored in the branch. -func (b *Branch) SetKey(key []byte) { - b.Key = key -} diff --git a/lib/trie/branch/writer_mock_test.go b/lib/trie/branch/writer_mock_test.go deleted file mode 100644 index 609c8f248d..0000000000 --- a/lib/trie/branch/writer_mock_test.go +++ /dev/null @@ -1,49 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: io (interfaces: Writer) - -// Package branch is a generated GoMock package. -package branch - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockWriter is a mock of Writer interface. -type MockWriter struct { - ctrl *gomock.Controller - recorder *MockWriterMockRecorder -} - -// MockWriterMockRecorder is the mock recorder for MockWriter. -type MockWriterMockRecorder struct { - mock *MockWriter -} - -// NewMockWriter creates a new mock instance. -func NewMockWriter(ctrl *gomock.Controller) *MockWriter { - mock := &MockWriter{ctrl: ctrl} - mock.recorder = &MockWriterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockWriter) EXPECT() *MockWriterMockRecorder { - return m.recorder -} - -// Write mocks base method. -func (m *MockWriter) Write(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), arg0) -} diff --git a/lib/trie/encode/key.go b/lib/trie/codec/nibbles.go similarity index 55% rename from lib/trie/encode/key.go rename to lib/trie/codec/nibbles.go index 3747df17b8..11e5a2e818 100644 --- a/lib/trie/encode/key.go +++ b/lib/trie/codec/nibbles.go @@ -1,45 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package encode - -import ( - "errors" - "fmt" - "io" -) - -const maxPartialKeySize = ^uint16(0) - -var ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than or equal to 2^16") - -// KeyLength encodes the public key length. 
-func KeyLength(keyLength int, writer io.Writer) (err error) { - keyLength -= 63 - - if keyLength >= int(maxPartialKeySize) { - return fmt.Errorf("%w: %d", - ErrPartialKeyTooBig, keyLength) - } - - for i := uint16(0); i < maxPartialKeySize; i++ { - if keyLength < 255 { - _, err = writer.Write([]byte{byte(keyLength)}) - if err != nil { - return err - } - break - } - _, err = writer.Write([]byte{255}) - if err != nil { - return err - } - - keyLength -= 255 - } - - return nil -} +package codec // NibblesToKeyLE converts a slice of nibbles with length k into a // Little Endian byte slice. @@ -64,3 +26,22 @@ func NibblesToKeyLE(nibbles []byte) (keyLE []byte) { return keyLE } + +// KeyLEToNibbles converts a Little Endian byte slice into nibbles. +// It assumes bytes are already in Little Endian and does not rearrange nibbles. +func KeyLEToNibbles(in []byte) (nibbles []byte) { + if len(in) == 0 { + return []byte{} + } else if len(in) == 1 && in[0] == 0 { + return []byte{0, 0} + } + + l := len(in) * 2 + nibbles = make([]byte, l) + for i, b := range in { + nibbles[2*i] = b / 16 + nibbles[2*i+1] = b % 16 + } + + return nibbles +} diff --git a/lib/trie/codec/nibbles_test.go b/lib/trie/codec/nibbles_test.go new file mode 100644 index 0000000000..fa2bbf4fdd --- /dev/null +++ b/lib/trie/codec/nibbles_test.go @@ -0,0 +1,142 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package codec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NibblesToKeyLE(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + nibbles []byte + keyLE []byte + }{ + "nil nibbles": { + keyLE: []byte{}, + }, + "empty nibbles": { + nibbles: []byte{}, + keyLE: []byte{}, + }, + "0xF 0xF": { + nibbles: []byte{0xF, 0xF}, + keyLE: []byte{0xFF}, + }, + "0x3 0xa 0x0 0x5": { + nibbles: []byte{0x3, 0xa, 0x0, 0x5}, + keyLE: []byte{0x3a, 0x05}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, + keyLE: []byte{0xaa, 0xff, 0x01}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc 0x2": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, + keyLE: []byte{0xaa, 0xff, 0x01, 0xc2}, + }, + "0xa 0xa 0xf 0xf 0x0 0x1 0xc": { + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, + keyLE: []byte{0xa, 0xaf, 0xf0, 0x1c}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + keyLE := NibblesToKeyLE(testCase.nibbles) + + assert.Equal(t, testCase.keyLE, keyLE) + }) + } +} + +func Test_KeyLEToNibbles(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + in []byte + nibbles []byte + }{ + "nil input": { + nibbles: []byte{}, + }, + "empty input": { + in: []byte{}, + nibbles: []byte{}, + }, + "0x0": { + in: []byte{0x0}, + nibbles: []byte{0, 0}}, + "0xFF": { + in: []byte{0xFF}, + nibbles: []byte{0xF, 0xF}}, + "0x3a 0x05": { + in: []byte{0x3a, 0x05}, + nibbles: []byte{0x3, 0xa, 0x0, 0x5}}, + "0xAA 0xFF 0x01": { + in: []byte{0xAA, 0xFF, 0x01}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}}, + "0xAA 0xFF 0x01 0xc2": { + in: []byte{0xAA, 0xFF, 0x01, 0xc2}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}}, + "0xAA 0xFF 0x01 0xc0": { + in: []byte{0xAA, 0xFF, 0x01, 0xc0}, + nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x0}}, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nibbles := KeyLEToNibbles(testCase.in) + + assert.Equal(t, testCase.nibbles, nibbles) + }) + } +} + +func 
Test_NibblesKeyLE(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + nibblesToEncode []byte + nibblesDecoded []byte + }{ + "empty input": { + nibblesToEncode: []byte{}, + nibblesDecoded: []byte{}, + }, + "one byte": { + nibblesToEncode: []byte{1}, + nibblesDecoded: []byte{0, 1}, + }, + "two bytes": { + nibblesToEncode: []byte{1, 2}, + nibblesDecoded: []byte{1, 2}, + }, + "three bytes": { + nibblesToEncode: []byte{1, 2, 3}, + nibblesDecoded: []byte{0, 1, 2, 3}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + keyLE := NibblesToKeyLE(testCase.nibblesToEncode) + nibblesDecoded := KeyLEToNibbles(keyLE) + + assert.Equal(t, testCase.nibblesDecoded, nibblesDecoded) + }) + } +} diff --git a/lib/trie/database.go b/lib/trie/database.go index 09c2a0c535..834b707f2c 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -9,9 +9,7 @@ import ( "fmt" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/chaindb" @@ -50,7 +48,7 @@ func (t *Trie) store(db chaindb.Batch, curr node.Node) error { return err } - if c, ok := curr.(*branch.Branch); ok { + if c, ok := curr.(*node.Branch); ok { for _, child := range c.Children { if child == nil { continue @@ -108,7 +106,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // loadProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root func (t *Trie) loadProof(proof map[string]node.Node, curr node.Node) { - c, ok := curr.(*branch.Branch) + c, ok := curr.(*node.Branch) if !ok { return } @@ -153,7 +151,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { } func (t *Trie) load(db chaindb.Database, curr node.Node) error { - if c, ok := curr.(*branch.Branch); ok { + if c, ok := curr.(*node.Branch); ok { for i, child := range c.Children { if child == nil { continue @@ -186,7 +184,7 @@ func (t *Trie) load(db chaindb.Database, curr node.Node) error { // GetNodeHashes return hash of each key of the trie. 
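The merged codec package pairs NibblesToKeyLE with the new KeyLEToNibbles, and the round-trip test above pins down the only asymmetry: an odd number of nibbles comes back with a leading zero nibble. A short usage sketch against the real import path; the printed values follow the test cases above:

package main

import (
	"fmt"

	"github.com/ChainSafe/gossamer/lib/trie/codec"
)

func main() {
	// An even number of nibbles round-trips exactly.
	keyLE := codec.NibblesToKeyLE([]byte{0x3, 0xa, 0x0, 0x5})
	fmt.Printf("% x\n", keyLE)               // 3a 05
	fmt.Println(codec.KeyLEToNibbles(keyLE)) // [3 10 0 5]

	// An odd number of nibbles gains a leading zero nibble,
	// so [1 2 3] decodes back as [0 1 2 3].
	fmt.Println(codec.KeyLEToNibbles(codec.NibblesToKeyLE([]byte{1, 2, 3})))
}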
func (t *Trie) GetNodeHashes(curr node.Node, keys map[common.Hash]struct{}) error { - if c, ok := curr.(*branch.Branch); ok { + if c, ok := curr.(*node.Branch); ok { for _, child := range c.Children { if child == nil { continue @@ -238,7 +236,7 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return nil, nil } - k := decode.KeyLEToNibbles(key) + k := codec.KeyLEToNibbles(key) enc, err := db.Get(root[:]) if err != nil { @@ -257,7 +255,7 @@ func getFromDB(db chaindb.Database, parent node.Node, key []byte) ([]byte, error var value []byte switch p := parent.(type) { - case *branch.Branch: + case *node.Branch: length := lenCommonPrefix(p.Key, key) // found the value at this node @@ -275,7 +273,7 @@ func getFromDB(db chaindb.Database, parent node.Node, key []byte) ([]byte, error } // load child with potential value - enc, err := db.Get(p.Children[key[length]].(*leaf.Leaf).Hash) + enc, err := db.Get(p.Children[key[length]].(*node.Leaf).Hash) if err != nil { return nil, fmt.Errorf("failed to find node in database: %w", err) } @@ -289,7 +287,7 @@ func getFromDB(db chaindb.Database, parent node.Node, key []byte) ([]byte, error if err != nil { return nil, err } - case *leaf.Leaf: + case *node.Leaf: if bytes.Equal(p.Key, key) { return p.Value, nil } @@ -337,7 +335,7 @@ func (t *Trie) writeDirty(db chaindb.Batch, curr node.Node) error { return err } - if c, ok := curr.(*branch.Branch); ok { + if c, ok := curr.(*node.Branch); ok { for _, child := range c.Children { if child == nil { continue @@ -383,7 +381,7 @@ func (t *Trie) getInsertedNodeHashes(curr node.Node) ([]common.Hash, error) { nodeHash := common.BytesToHash(hash) nodeHashes = append(nodeHashes, nodeHash) - if c, ok := curr.(*branch.Branch); ok { + if c, ok := curr.(*node.Branch); ok { for _, child := range c.Children { if child == nil { continue diff --git a/lib/trie/decode.go b/lib/trie/decode.go index 46bf12d751..c3305615c7 100644 --- a/lib/trie/decode.go +++ b/lib/trie/decode.go @@ -9,8 +9,6 @@ import ( "fmt" "io" - "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/leaf" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/lib/trie/pools" ) @@ -33,13 +31,13 @@ func decodeNode(reader io.Reader) (n node.Node, err error) { nodeType := header >> 6 switch nodeType { case node.LeafType: - n, err = leaf.Decode(reader, header) + n, err = node.DecodeLeaf(reader, header) if err != nil { return nil, fmt.Errorf("cannot decode leaf: %w", err) } return n, nil case node.BranchType, node.BranchWithValueType: - n, err = branch.Decode(reader, header) + n, err = node.DecodeBranch(reader, header) if err != nil { return nil, fmt.Errorf("cannot decode branch: %w", err) } diff --git a/lib/trie/decode/key_test.go b/lib/trie/decode/key_test.go deleted file mode 100644 index c25749d7e8..0000000000 --- a/lib/trie/decode/key_test.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package decode - -import ( - "bytes" - "io" - "testing" - - "github.com/stretchr/testify/assert" -) - -func repeatBytes(n int, b byte) (slice []byte) { - slice = make([]byte, n) - for i := range slice { - slice[i] = b - } - return slice -} - -func Test_Key(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - reader io.Reader - keyLength byte - b []byte - errWrapped error - errMessage string - }{ - "zero key length": { - b: []byte{}, - }, - "short key length": { - reader: bytes.NewBuffer([]byte{1, 2, 3}), - 
keyLength: 5, - b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, - }, - "key read error": { - reader: bytes.NewBuffer(nil), - keyLength: 5, - errWrapped: ErrReadKeyData, - errMessage: "cannot read key data: EOF", - }, - "long key length": { - reader: bytes.NewBuffer( - append( - []byte{ - 6, // key length - }, - repeatBytes(64, 7)..., // key data - )), - keyLength: 0x3f, - b: []byte{ - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, - }, - "key length read error": { - reader: bytes.NewBuffer(nil), - keyLength: 0x3f, - errWrapped: ErrReadKeyLength, - errMessage: "cannot read key length: EOF", - }, - "key length too big": { - reader: bytes.NewBuffer(repeatBytes(257, 0xff)), - keyLength: 0x3f, - errWrapped: ErrPartialKeyTooBig, - errMessage: "partial key length cannot be larger than or equal to 2^16: 65598", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - b, err := Key(testCase.reader, testCase.keyLength) - - assert.ErrorIs(t, err, testCase.errWrapped) - if err != nil { - assert.EqualError(t, err, testCase.errMessage) - } - assert.Equal(t, testCase.b, b) - }) - } -} - -func Test_KeyToNibbles(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - in []byte - nibbles []byte - }{ - "nil input": { - nibbles: []byte{}, - }, - "empty input": { - in: []byte{}, - nibbles: []byte{}, - }, - "0x0": { - in: []byte{0x0}, - nibbles: []byte{0, 0}}, - "0xFF": { - in: []byte{0xFF}, - nibbles: []byte{0xF, 0xF}}, - "0x3a 0x05": { - in: []byte{0x3a, 0x05}, - nibbles: []byte{0x3, 0xa, 0x0, 0x5}}, - "0xAA 0xFF 0x01": { - in: []byte{0xAA, 0xFF, 0x01}, - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}}, - "0xAA 0xFF 0x01 0xc2": { - in: []byte{0xAA, 0xFF, 0x01, 0xc2}, - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}}, - "0xAA 0xFF 0x01 0xc0": { - in: []byte{0xAA, 0xFF, 0x01, 0xc0}, - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x0}}, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - nibbles := KeyLEToNibbles(testCase.in) - - assert.Equal(t, testCase.nibbles, nibbles) - }) - } -} diff --git a/lib/trie/decode_test.go b/lib/trie/decode_test.go index 32002d4f24..a4c1b4c068 100644 --- a/lib/trie/decode_test.go +++ b/lib/trie/decode_test.go @@ -8,9 +8,6 @@ import ( "io" "testing" - "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/assert" @@ -47,7 +44,7 @@ func Test_decodeNode(t *testing.T) { 65, // node type 1 and key length 1 // missing key data byte }), - errWrapped: decode.ErrReadKeyData, + errWrapped: node.ErrReadKeyData, errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", }, "leaf success": { @@ -60,7 +57,7 @@ func Test_decodeNode(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3)..., ), ), - n: &leaf.Leaf{ + n: &node.Leaf{ Key: []byte{9}, Value: []byte{1, 2, 3}, Dirty: true, @@ -71,7 +68,7 @@ func Test_decodeNode(t *testing.T) { 129, // node type 2 and key length 1 // missing key data byte }), - errWrapped: 
decode.ErrReadKeyData, + errWrapped: node.ErrReadKeyData, errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", }, "branch success": { @@ -82,7 +79,7 @@ func Test_decodeNode(t *testing.T) { 0, 0, // no children bitmap }, ), - n: &branch.Branch{ + n: &node.Branch{ Key: []byte{9}, Dirty: true, }, diff --git a/lib/trie/encode/writer_mock_test.go b/lib/trie/encode/writer_mock_test.go deleted file mode 100644 index 171dc73964..0000000000 --- a/lib/trie/encode/writer_mock_test.go +++ /dev/null @@ -1,49 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: io (interfaces: Writer) - -// Package encode is a generated GoMock package. -package encode - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockWriter is a mock of Writer interface. -type MockWriter struct { - ctrl *gomock.Controller - recorder *MockWriterMockRecorder -} - -// MockWriterMockRecorder is the mock recorder for MockWriter. -type MockWriterMockRecorder struct { - mock *MockWriter -} - -// NewMockWriter creates a new mock instance. -func NewMockWriter(ctrl *gomock.Controller) *MockWriter { - mock := &MockWriter{ctrl: ctrl} - mock.recorder = &MockWriterMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockWriter) EXPECT() *MockWriterMockRecorder { - return m.recorder -} - -// Write mocks base method. -func (m *MockWriter) Write(arg0 []byte) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write. -func (mr *MockWriterMockRecorder) Write(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockWriter)(nil).Write), arg0) -} diff --git a/lib/trie/encodedecode_test/nibbles_test.go b/lib/trie/encodedecode_test/nibbles_test.go deleted file mode 100644 index 05fd0b5a95..0000000000 --- a/lib/trie/encodedecode_test/nibbles_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package encodedecode_test - -import ( - "testing" - - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/lib/trie/encode" - "github.com/stretchr/testify/assert" -) - -func Test_NibblesKeyLE(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - nibblesToEncode []byte - nibblesDecoded []byte - }{ - "empty input": { - nibblesToEncode: []byte{}, - nibblesDecoded: []byte{}, - }, - "one byte": { - nibblesToEncode: []byte{1}, - nibblesDecoded: []byte{0, 1}, - }, - "two bytes": { - nibblesToEncode: []byte{1, 2}, - nibblesDecoded: []byte{1, 2}, - }, - "three bytes": { - nibblesToEncode: []byte{1, 2, 3}, - nibblesDecoded: []byte{0, 1, 2, 3}, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - keyLE := encode.NibblesToKeyLE(testCase.nibblesToEncode) - nibblesDecoded := decode.KeyLEToNibbles(keyLE) - - assert.Equal(t, testCase.nibblesDecoded, nibblesDecoded) - }) - } -} diff --git a/lib/trie/leaf/copy.go b/lib/trie/leaf/copy.go deleted file mode 100644 index 3f07972249..0000000000 --- a/lib/trie/leaf/copy.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -import "github.com/ChainSafe/gossamer/lib/trie/node" - -// Copy deep copies 
the leaf. -func (l *Leaf) Copy() node.Node { - l.RLock() - defer l.RUnlock() - - l.encodingMu.RLock() - defer l.encodingMu.RUnlock() - - cpy := &Leaf{ - Key: make([]byte, len(l.Key)), - Value: make([]byte, len(l.Value)), - Dirty: l.Dirty, - Hash: make([]byte, len(l.Hash)), - Encoding: make([]byte, len(l.Encoding)), - Generation: l.Generation, - } - copy(cpy.Key, l.Key) - copy(cpy.Value, l.Value) - copy(cpy.Hash, l.Hash) - copy(cpy.Encoding, l.Encoding) - return cpy -} diff --git a/lib/trie/leaf/decode.go b/lib/trie/leaf/decode.go deleted file mode 100644 index 6c20fffbcc..0000000000 --- a/lib/trie/leaf/decode.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -import ( - "errors" - "fmt" - "io" - - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/pkg/scale" -) - -var ( - ErrReadHeaderByte = errors.New("cannot read header byte") - ErrNodeTypeIsNotALeaf = errors.New("node type is not a leaf") - ErrDecodeValue = errors.New("cannot decode value") -) - -// Decode reads and decodes from a reader with the encoding specified in lib/trie/encode/doc.go. -func Decode(r io.Reader, header byte) (leaf *Leaf, err error) { - nodeType := header >> 6 - if nodeType != 1 { - return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) - } - - leaf = new(Leaf) - - keyLen := header & 0x3f - leaf.Key, err = decode.Key(r, keyLen) - if err != nil { - return nil, fmt.Errorf("cannot decode key: %w", err) - } - - sd := scale.NewDecoder(r) - var value []byte - err = sd.Decode(&value) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) - } - - if len(value) > 0 { - leaf.Value = value - } - - leaf.Dirty = true - - return leaf, nil -} diff --git a/lib/trie/leaf/decode_test.go b/lib/trie/leaf/decode_test.go deleted file mode 100644 index e702b666e2..0000000000 --- a/lib/trie/leaf/decode_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -import ( - "bytes" - "io" - "testing" - - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/pkg/scale" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { - encoded, err := scale.Marshal(b) - require.NoError(t, err) - return encoded -} - -func concatByteSlices(slices [][]byte) (concatenated []byte) { - length := 0 - for i := range slices { - length += len(slices[i]) - } - concatenated = make([]byte, 0, length) - for _, slice := range slices { - concatenated = append(concatenated, slice...) 
- } - return concatenated -} - -func Test_Decode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - reader io.Reader - header byte - leaf *Leaf - errWrapped error - errMessage string - }{ - "no data with header 1": { - reader: bytes.NewBuffer(nil), - header: 1, - errWrapped: ErrNodeTypeIsNotALeaf, - errMessage: "node type is not a leaf: 0", - }, - "key decoding error": { - reader: bytes.NewBuffer([]byte{ - // missing key data byte - }), - header: 65, // node type 1 and key length 1 - errWrapped: decode.ErrReadKeyData, - errMessage: "cannot decode key: cannot read key data: EOF", - }, - "value decoding error": { - reader: bytes.NewBuffer([]byte{ - 9, // key data - // missing value data - }), - header: 65, // node type 1 and key length 1 - errWrapped: ErrDecodeValue, - errMessage: "cannot decode value: EOF", - }, - "zero value": { - reader: bytes.NewBuffer([]byte{ - 9, // key data - 0, // missing value data - }), - header: 65, // node type 1 and key length 1 - leaf: &Leaf{ - Key: []byte{9}, - Dirty: true, - }, - }, - "success": { - reader: bytes.NewBuffer( - concatByteSlices([][]byte{ - {9}, // key data - scaleEncodeBytes(t, 1, 2, 3, 4, 5), // value data - }), - ), - header: 65, // node type 1 and key length 1 - leaf: &Leaf{ - Key: []byte{9}, - Value: []byte{1, 2, 3, 4, 5}, - Dirty: true, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - leaf, err := Decode(testCase.reader, testCase.header) - - assert.ErrorIs(t, err, testCase.errWrapped) - if err != nil { - assert.EqualError(t, err, testCase.errMessage) - } - assert.Equal(t, testCase.leaf, leaf) - }) - } -} diff --git a/lib/trie/leaf/dirty.go b/lib/trie/leaf/dirty.go deleted file mode 100644 index b955754b03..0000000000 --- a/lib/trie/leaf/dirty.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -// IsDirty returns the dirty status of the leaf. -func (l *Leaf) IsDirty() bool { - return l.Dirty -} - -// SetDirty sets the dirty status to the leaf. -func (l *Leaf) SetDirty(dirty bool) { - l.Dirty = dirty -} diff --git a/lib/trie/leaf/generation.go b/lib/trie/leaf/generation.go deleted file mode 100644 index 1ce46bf81d..0000000000 --- a/lib/trie/leaf/generation.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -// SetGeneration sets the generation given to the leaf. -func (l *Leaf) SetGeneration(generation uint64) { - l.Generation = generation -} - -// GetGeneration returns the generation of the leaf. -func (l *Leaf) GetGeneration() uint64 { - return l.Generation -} diff --git a/lib/trie/leaf/header.go b/lib/trie/leaf/header.go deleted file mode 100644 index 930658f65f..0000000000 --- a/lib/trie/leaf/header.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -import ( - "io" - - "github.com/ChainSafe/gossamer/lib/trie/encode" -) - -// encodeHeader creates the encoded header for the leaf. 
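The header bytes used throughout these tests, such as 65 for node type 1 with key length 1, follow one layout: the top two bits carry the node variant and the low six bits carry the partial key length, with 0x3f meaning extra length bytes follow. A small illustrative sketch, where splitHeader is a made-up name rather than a function in the codebase:

package main

import "fmt"

// splitHeader sketches how decoders read the first byte of a node
// encoding: the top two bits give the variant (1 leaf, 2 branch,
// 3 branch with value) and the low six bits give the partial key
// length field, where 0x3f means extra length bytes follow.
func splitHeader(header byte) (nodeType, keyLength byte) {
	return header >> 6, header & 0x3f
}

func main() {
	for _, header := range []byte{65, 0x80, 0xc3, 0x7f} {
		nodeType, keyLength := splitHeader(header)
		fmt.Printf("header %#02x: type %d, key length field %d\n",
			header, nodeType, keyLength)
	}
}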
-func (l *Leaf) encodeHeader(writer io.Writer) (err error) { - var header byte = 1 << 6 - - if len(l.Key) < 63 { - header = header | byte(len(l.Key)) - _, err = writer.Write([]byte{header}) - return err - } - - header = header | 0x3f - _, err = writer.Write([]byte{header}) - if err != nil { - return err - } - - err = encode.KeyLength(len(l.Key), writer) - if err != nil { - return err - } - - return nil -} diff --git a/lib/trie/leaf/header_test.go b/lib/trie/leaf/header_test.go deleted file mode 100644 index 08f03d0ed9..0000000000 --- a/lib/trie/leaf/header_test.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -import ( - "testing" - - "github.com/ChainSafe/gossamer/lib/trie/encode" - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" -) - -func Test_Leaf_encodeHeader(t *testing.T) { - testCases := map[string]struct { - leaf *Leaf - writes []writeCall - errWrapped error - errMessage string - }{ - "no key": { - leaf: &Leaf{}, - writes: []writeCall{ - {written: []byte{0x40}}, - }, - }, - "key of length 30": { - leaf: &Leaf{ - Key: make([]byte, 30), - }, - writes: []writeCall{ - {written: []byte{0x5e}}, - }, - }, - "short key write error": { - leaf: &Leaf{ - Key: make([]byte, 30), - }, - writes: []writeCall{ - { - written: []byte{0x5e}, - err: errTest, - }, - }, - errWrapped: errTest, - errMessage: errTest.Error(), - }, - "key of length 62": { - leaf: &Leaf{ - Key: make([]byte, 62), - }, - writes: []writeCall{ - {written: []byte{0x7e}}, - }, - }, - "key of length 63": { - leaf: &Leaf{ - Key: make([]byte, 63), - }, - writes: []writeCall{ - {written: []byte{0x7f}}, - {written: []byte{0x0}}, - }, - }, - "key of length 64": { - leaf: &Leaf{ - Key: make([]byte, 64), - }, - writes: []writeCall{ - {written: []byte{0x7f}}, - {written: []byte{0x1}}, - }, - }, - "long key first byte write error": { - leaf: &Leaf{ - Key: make([]byte, 63), - }, - writes: []writeCall{ - { - written: []byte{0x7f}, - err: errTest, - }, - }, - errWrapped: errTest, - errMessage: errTest.Error(), - }, - "key too big": { - leaf: &Leaf{ - Key: make([]byte, 65535+63), - }, - writes: []writeCall{ - {written: []byte{0x7f}}, - }, - errWrapped: encode.ErrPartialKeyTooBig, - errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - writer := NewMockWriter(ctrl) - var previousCall *gomock.Call - for _, write := range testCase.writes { - call := writer.EXPECT(). - Write(write.written). - Return(write.n, write.err) - - if previousCall != nil { - call.After(previousCall) - } - previousCall = call - } - - err := testCase.leaf.encodeHeader(writer) - - assert.ErrorIs(t, err, testCase.errWrapped) - if testCase.errWrapped != nil { - assert.EqualError(t, err, testCase.errMessage) - } - }) - } -} diff --git a/lib/trie/leaf/key.go b/lib/trie/leaf/key.go deleted file mode 100644 index 9a7d3a11d6..0000000000 --- a/lib/trie/leaf/key.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package leaf - -// SetKey sets the key to the leaf. -// Note it does not copy it so modifying the passed key -// will modify the key stored in the leaf. 
-func (l *Leaf) SetKey(key []byte) { - l.Key = key -} diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index e51596dc69..c0680a88fc 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -6,7 +6,6 @@ package trie import ( "bytes" - "github.com/ChainSafe/gossamer/lib/trie/branch" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/lib/trie/record" ) @@ -30,7 +29,7 @@ func find(parent node.Node, key []byte, recorder recorder) error { recorder.Record(hash, enc) - b, ok := parent.(*branch.Branch) + b, ok := parent.(*node.Branch) if !ok { return nil } diff --git a/lib/trie/branch/branch.go b/lib/trie/node/branch.go similarity index 85% rename from lib/trie/branch/branch.go rename to lib/trie/node/branch.go index a5d0dfc36d..45547d1b3b 100644 --- a/lib/trie/branch/branch.go +++ b/lib/trie/node/branch.go @@ -1,22 +1,21 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "fmt" "sync" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/node" ) -var _ node.Node = (*Branch)(nil) +var _ Node = (*Branch)(nil) // Branch is a branch in the trie. type Branch struct { Key []byte // partial key - Children [16]node.Node + Children [16]Node Value []byte Dirty bool Hash []byte diff --git a/lib/trie/branch/encode.go b/lib/trie/node/branch_encode.go similarity index 90% rename from lib/trie/branch/encode.go rename to lib/trie/node/branch_encode.go index 37d49f74cd..e24fdd7261 100644 --- a/lib/trie/branch/encode.go +++ b/lib/trie/node/branch_encode.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "bytes" @@ -10,9 +10,7 @@ import ( "io" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/encode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" - "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/lib/trie/pools" "github.com/ChainSafe/gossamer/pkg/scale" ) @@ -77,7 +75,7 @@ func (b *Branch) hash(digestBuffer io.Writer) (err error) { // Encode encodes a branch with the encoding specified at the top of this package // to the buffer given. 
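After the merge, encoding a branch only needs the node package: Encode writes the header, the little-endian key, the children bitmap, the SCALE-encoded value and then the encoded children into any Buffer implementation such as a bytes.Buffer. A small usage sketch; the node values are arbitrary examples:

package main

import (
	"bytes"
	"fmt"

	"github.com/ChainSafe/gossamer/lib/trie/node"
)

func main() {
	branch := &node.Branch{
		Key:   []byte{1, 2, 3},
		Value: []byte{100},
		Children: [16]node.Node{
			nil, nil, nil, &node.Leaf{Key: []byte{9}},
		},
	}

	buffer := bytes.NewBuffer(nil)
	if err := branch.Encode(buffer); err != nil {
		panic(err)
	}

	// Header, key, children bitmap, SCALE-encoded value,
	// then the SCALE-encoded child at index 3.
	fmt.Printf("% x\n", buffer.Bytes())
}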
-func (b *Branch) Encode(buffer encode.Buffer) (err error) { +func (b *Branch) Encode(buffer Buffer) (err error) { if !b.Dirty && b.Encoding != nil { _, err = buffer.Write(b.Encoding) if err != nil { @@ -91,7 +89,7 @@ func (b *Branch) Encode(buffer encode.Buffer) (err error) { return fmt.Errorf("cannot encode header: %w", err) } - keyLE := encode.NibblesToKeyLE(b.Key) + keyLE := codec.NibblesToKeyLE(b.Key) _, err = buffer.Write(keyLE) if err != nil { return fmt.Errorf("cannot write encoded key to buffer: %w", err) @@ -128,7 +126,7 @@ func (b *Branch) Encode(buffer encode.Buffer) (err error) { return nil } -func encodeChildrenInParallel(children [16]node.Node, buffer io.Writer) (err error) { +func encodeChildrenInParallel(children [16]Node, buffer io.Writer) (err error) { type result struct { index int buffer *bytes.Buffer @@ -138,7 +136,7 @@ func encodeChildrenInParallel(children [16]node.Node, buffer io.Writer) (err err resultsCh := make(chan result) for i, child := range children { - go func(index int, child node.Node) { + go func(index int, child Node) { buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() // buffer is put back in the pool after processing its @@ -195,7 +193,7 @@ func encodeChildrenInParallel(children [16]node.Node, buffer io.Writer) (err err return err } -func encodeChildrenSequentially(children [16]node.Node, buffer io.Writer) (err error) { +func encodeChildrenSequentially(children [16]Node, buffer io.Writer) (err error) { for i, child := range children { err = encodeChild(child, buffer) if err != nil { @@ -205,12 +203,12 @@ func encodeChildrenSequentially(children [16]node.Node, buffer io.Writer) (err e return nil } -func encodeChild(child node.Node, buffer io.Writer) (err error) { +func encodeChild(child Node, buffer io.Writer) (err error) { var isNil bool switch impl := child.(type) { case *Branch: isNil = impl == nil - case *leaf.Leaf: + case *Leaf: isNil = impl == nil default: isNil = child == nil diff --git a/lib/trie/branch/encode_test.go b/lib/trie/node/branch_encode_test.go similarity index 86% rename from lib/trie/branch/encode_test.go rename to lib/trie/node/branch_encode_test.go index 3de07d6f5c..2589e37e0f 100644 --- a/lib/trie/branch/encode_test.go +++ b/lib/trie/node/branch_encode_test.go @@ -1,28 +1,16 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( - "errors" "testing" - "github.com/ChainSafe/gossamer/lib/trie/encode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" - "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type writeCall struct { - written []byte - n int - err error -} - -var errTest = errors.New("test error") - //go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/encode Buffer func Test_Branch_Encode(t *testing.T) { @@ -66,7 +54,7 @@ func Test_Branch_Encode(t *testing.T) { written: []byte{191}, }, }, - wrappedErr: encode.ErrPartialKeyTooBig, + wrappedErr: ErrPartialKeyTooBig, errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", }, "buffer write error for encoded key": { @@ -90,9 +78,9 @@ func Test_Branch_Encode(t *testing.T) { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, - Children: [16]node.Node{ - nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, - nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + Children: [16]Node{ + nil, nil, 
nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, }, }, writes: []writeCall{ @@ -114,9 +102,9 @@ func Test_Branch_Encode(t *testing.T) { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, - Children: [16]node.Node{ - nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, - nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + Children: [16]Node{ + nil, nil, nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, }, }, writes: []writeCall{ @@ -141,9 +129,9 @@ func Test_Branch_Encode(t *testing.T) { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, - Children: [16]node.Node{ - nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, - nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + Children: [16]Node{ + nil, nil, nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, }, }, writes: []writeCall{ @@ -173,9 +161,9 @@ func Test_Branch_Encode(t *testing.T) { branch: &Branch{ Key: []byte{1, 2, 3}, Value: []byte{100}, - Children: [16]node.Node{ - nil, nil, nil, &leaf.Leaf{Key: []byte{9}}, - nil, nil, nil, &leaf.Leaf{Key: []byte{11}}, + Children: [16]Node{ + nil, nil, nil, &Leaf{Key: []byte{9}}, + nil, nil, nil, &Leaf{Key: []byte{11}}, }, }, writes: []writeCall{ @@ -236,15 +224,15 @@ func Test_encodeChildrenInParallel(t *testing.T) { t.Parallel() testCases := map[string]struct { - children [16]node.Node + children [16]Node writes []writeCall wrappedErr error errMessage string }{ "no children": {}, "first child not nil": { - children: [16]node.Node{ - &leaf.Leaf{Key: []byte{1}}, + children: [16]Node{ + &Leaf{Key: []byte{1}}, }, writes: []writeCall{ { @@ -253,11 +241,11 @@ func Test_encodeChildrenInParallel(t *testing.T) { }, }, "last child not nil": { - children: [16]node.Node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{Key: []byte{1}}, + &Leaf{Key: []byte{1}}, }, writes: []writeCall{ { @@ -266,9 +254,9 @@ func Test_encodeChildrenInParallel(t *testing.T) { }, }, "first two children not nil": { - children: [16]node.Node{ - &leaf.Leaf{Key: []byte{1}}, - &leaf.Leaf{Key: []byte{2}}, + children: [16]Node{ + &Leaf{Key: []byte{1}}, + &Leaf{Key: []byte{2}}, }, writes: []writeCall{ { @@ -280,11 +268,11 @@ func Test_encodeChildrenInParallel(t *testing.T) { }, }, "encoding error": { - children: [16]node.Node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{ + &Leaf{ Key: []byte{1}, }, nil, nil, nil, nil, @@ -336,15 +324,15 @@ func Test_encodeChildrenSequentially(t *testing.T) { t.Parallel() testCases := map[string]struct { - children [16]node.Node + children [16]Node writes []writeCall wrappedErr error errMessage string }{ "no children": {}, "first child not nil": { - children: [16]node.Node{ - &leaf.Leaf{Key: []byte{1}}, + children: [16]Node{ + &Leaf{Key: []byte{1}}, }, writes: []writeCall{ { @@ -353,11 +341,11 @@ func Test_encodeChildrenSequentially(t *testing.T) { }, }, "last child not nil": { - children: [16]node.Node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{Key: []byte{1}}, + &Leaf{Key: []byte{1}}, }, writes: []writeCall{ { @@ -366,9 +354,9 @@ func Test_encodeChildrenSequentially(t *testing.T) { }, }, "first two children not nil": { - children: [16]node.Node{ - &leaf.Leaf{Key: []byte{1}}, - &leaf.Leaf{Key: []byte{2}}, + children: [16]Node{ + &Leaf{Key: []byte{1}}, + &Leaf{Key: []byte{2}}, }, writes: []writeCall{ { @@ -380,11 +368,11 @@ func Test_encodeChildrenSequentially(t *testing.T) { }, }, "encoding 
error": { - children: [16]node.Node{ + children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{ + &Leaf{ Key: []byte{1}, }, nil, nil, nil, nil, @@ -438,7 +426,7 @@ func Test_encodeChild(t *testing.T) { t.Parallel() testCases := map[string]struct { - child node.Node + child Node writeCall bool write writeCall wrappedErr error @@ -446,13 +434,13 @@ func Test_encodeChild(t *testing.T) { }{ "nil node": {}, "nil leaf": { - child: (*leaf.Leaf)(nil), + child: (*Leaf)(nil), }, "nil branch": { child: (*Branch)(nil), }, "empty leaf child": { - child: &leaf.Leaf{}, + child: &Leaf{}, writeCall: true, write: writeCall{ written: []byte{8, 64, 0}, @@ -476,7 +464,7 @@ func Test_encodeChild(t *testing.T) { errMessage: "failed to write child to buffer: test error", }, "leaf child": { - child: &leaf.Leaf{ + child: &Leaf{ Key: []byte{1}, Value: []byte{2}, }, @@ -489,8 +477,8 @@ func Test_encodeChild(t *testing.T) { child: &Branch{ Key: []byte{1}, Value: []byte{2}, - Children: [16]node.Node{ - nil, nil, &leaf.Leaf{ + Children: [16]Node{ + nil, nil, &Leaf{ Key: []byte{5}, Value: []byte{6}, }, diff --git a/lib/trie/encode/buffer.go b/lib/trie/node/buffer.go similarity index 94% rename from lib/trie/encode/buffer.go rename to lib/trie/node/buffer.go index 748f30ed97..c1515f94c6 100644 --- a/lib/trie/encode/buffer.go +++ b/lib/trie/node/buffer.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package encode +package node import "io" diff --git a/lib/trie/leaf/buffer_mock_test.go b/lib/trie/node/buffer_mock_test.go similarity index 97% rename from lib/trie/leaf/buffer_mock_test.go rename to lib/trie/node/buffer_mock_test.go index d3404d0a48..1357336f23 100644 --- a/lib/trie/leaf/buffer_mock_test.go +++ b/lib/trie/node/buffer_mock_test.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/ChainSafe/gossamer/lib/trie/encode (interfaces: Buffer) -// Package leaf is a generated GoMock package. -package leaf +// Package node is a generated GoMock package. +package node import ( reflect "reflect" diff --git a/lib/trie/branch/children.go b/lib/trie/node/children.go similarity index 97% rename from lib/trie/branch/children.go rename to lib/trie/node/children.go index ff911dc513..bd581cf657 100644 --- a/lib/trie/branch/children.go +++ b/lib/trie/node/children.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node // ChildrenBitmap returns the 16 bit bitmap // of the children in the branch. 
diff --git a/lib/trie/branch/children_test.go b/lib/trie/node/children_test.go similarity index 75% rename from lib/trie/branch/children_test.go rename to lib/trie/node/children_test.go index a525a57769..4b60039656 100644 --- a/lib/trie/branch/children_test.go +++ b/lib/trie/node/children_test.go @@ -1,13 +1,11 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "testing" - "github.com/ChainSafe/gossamer/lib/trie/leaf" - "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/stretchr/testify/assert" ) @@ -23,31 +21,31 @@ func Test_Branch_ChildrenBitmap(t *testing.T) { }, "index 0": { branch: &Branch{ - Children: [16]node.Node{ - &leaf.Leaf{}, + Children: [16]Node{ + &Leaf{}, }, }, bitmap: 1, }, "index 0 and 4": { branch: &Branch{ - Children: [16]node.Node{ - &leaf.Leaf{}, + Children: [16]Node{ + &Leaf{}, nil, nil, nil, - &leaf.Leaf{}, + &Leaf{}, }, }, bitmap: 1<<4 + 1, }, "index 0, 4 and 15": { branch: &Branch{ - Children: [16]node.Node{ - &leaf.Leaf{}, + Children: [16]Node{ + &Leaf{}, nil, nil, nil, - &leaf.Leaf{}, + &Leaf{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{}, + &Leaf{}, }, }, bitmap: 1<<15 + 1<<4 + 1, @@ -78,31 +76,31 @@ func Test_Branch_NumChildren(t *testing.T) { }, "one": { branch: &Branch{ - Children: [16]node.Node{ - &leaf.Leaf{}, + Children: [16]Node{ + &Leaf{}, }, }, count: 1, }, "two": { branch: &Branch{ - Children: [16]node.Node{ - &leaf.Leaf{}, + Children: [16]Node{ + &Leaf{}, nil, nil, nil, - &leaf.Leaf{}, + &Leaf{}, }, }, count: 2, }, "three": { branch: &Branch{ - Children: [16]node.Node{ - &leaf.Leaf{}, + Children: [16]Node{ + &Leaf{}, nil, nil, nil, - &leaf.Leaf{}, + &Leaf{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{}, + &Leaf{}, }, }, count: 3, diff --git a/lib/trie/branch/copy.go b/lib/trie/node/copy.go similarity index 54% rename from lib/trie/branch/copy.go rename to lib/trie/node/copy.go index 6e7d78b6b3..41f1b068d1 100644 --- a/lib/trie/branch/copy.go +++ b/lib/trie/node/copy.go @@ -1,12 +1,10 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch - -import "github.com/ChainSafe/gossamer/lib/trie/node" +package node // Copy deep copies the branch. -func (b *Branch) Copy() node.Node { +func (b *Branch) Copy() Node { b.RLock() defer b.RUnlock() @@ -31,3 +29,26 @@ func (b *Branch) Copy() node.Node { copy(cpy.Encoding, b.Encoding) return cpy } + +// Copy deep copies the leaf. 
+func (l *Leaf) Copy() Node { + l.RLock() + defer l.RUnlock() + + l.encodingMu.RLock() + defer l.encodingMu.RUnlock() + + cpy := &Leaf{ + Key: make([]byte, len(l.Key)), + Value: make([]byte, len(l.Value)), + Dirty: l.Dirty, + Hash: make([]byte, len(l.Hash)), + Encoding: make([]byte, len(l.Encoding)), + Generation: l.Generation, + } + copy(cpy.Key, l.Key) + copy(cpy.Value, l.Value) + copy(cpy.Hash, l.Hash) + copy(cpy.Encoding, l.Encoding) + return cpy +} diff --git a/lib/trie/branch/decode.go b/lib/trie/node/decode.go similarity index 63% rename from lib/trie/branch/decode.go rename to lib/trie/node/decode.go index 118f071386..a640196c4f 100644 --- a/lib/trie/branch/decode.go +++ b/lib/trie/node/decode.go @@ -1,32 +1,31 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "errors" "fmt" "io" - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" "github.com/ChainSafe/gossamer/pkg/scale" ) var ( ErrReadHeaderByte = errors.New("cannot read header byte") ErrNodeTypeIsNotABranch = errors.New("node type is not a branch") - ErrReadChildrenBitmap = errors.New("cannot read children bitmap") + ErrNodeTypeIsNotALeaf = errors.New("node type is not a leaf") ErrDecodeValue = errors.New("cannot decode value") + ErrReadChildrenBitmap = errors.New("cannot read children bitmap") ErrDecodeChildHash = errors.New("cannot decode child hash") ) -// Decode reads and decodes from a reader with the encoding specified in lib/trie/encode/doc.go. +// DecodeBranch reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. // Note that since the encoded branch stores the hash of the children nodes, we are not // reconstructing the child nodes from the encoding. This function instead stubs where the // children are known to be with an empty leaf. The children nodes hashes are then used to // find other values using the persistent database. -func Decode(reader io.Reader, header byte) (branch *Branch, err error) { +func DecodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { nodeType := header >> 6 if nodeType != 2 && nodeType != 3 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) @@ -35,7 +34,7 @@ func Decode(reader io.Reader, header byte) (branch *Branch, err error) { branch = new(Branch) keyLen := header & 0x3f - branch.Key, err = decode.Key(reader, keyLen) + branch.Key, err = decodeKey(reader, keyLen) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) } @@ -69,7 +68,7 @@ func Decode(reader io.Reader, header byte) (branch *Branch, err error) { ErrDecodeChildHash, i, err) } - branch.Children[i] = &leaf.Leaf{ + branch.Children[i] = &Leaf{ Hash: hash, } } @@ -78,3 +77,34 @@ func Decode(reader io.Reader, header byte) (branch *Branch, err error) { return branch, nil } + +// DecodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. 
+func DecodeLeaf(r io.Reader, header byte) (leaf *Leaf, err error) { + nodeType := header >> 6 + if nodeType != 1 { + return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) + } + + leaf = new(Leaf) + + keyLen := header & 0x3f + leaf.Key, err = decodeKey(r, keyLen) + if err != nil { + return nil, fmt.Errorf("cannot decode key: %w", err) + } + + sd := scale.NewDecoder(r) + var value []byte + err = sd.Decode(&value) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) + } + + if len(value) > 0 { + leaf.Value = value + } + + leaf.Dirty = true + + return leaf, nil +} diff --git a/lib/trie/branch/decode_test.go b/lib/trie/node/decode_test.go similarity index 63% rename from lib/trie/branch/decode_test.go rename to lib/trie/node/decode_test.go index db2ce1f43f..daf64b5992 100644 --- a/lib/trie/branch/decode_test.go +++ b/lib/trie/node/decode_test.go @@ -1,16 +1,13 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "bytes" "io" "testing" - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" - "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,7 +31,7 @@ func concatByteSlices(slices [][]byte) (concatenated []byte) { return concatenated } -func Test_Decode(t *testing.T) { +func Test_DecodeBranch(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -55,7 +52,7 @@ func Test_Decode(t *testing.T) { // missing key data byte }), header: 129, // node type 2 and key length 1 - errWrapped: decode.ErrReadKeyData, + errWrapped: ErrReadKeyData, errMessage: "cannot decode key: cannot read key data: EOF", }, "children bitmap read error": { @@ -90,10 +87,10 @@ func Test_Decode(t *testing.T) { header: 129, // node type 2 and key length 1 branch: &Branch{ Key: []byte{9}, - Children: [16]node.Node{ + Children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{ + &Leaf{ Hash: []byte{1, 2, 3, 4, 5}, }, }, @@ -125,10 +122,10 @@ func Test_Decode(t *testing.T) { branch: &Branch{ Key: []byte{9}, Value: []byte{7, 8, 9}, - Children: [16]node.Node{ + Children: [16]Node{ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, - &leaf.Leaf{ + &Leaf{ Hash: []byte{1, 2, 3, 4, 5}, }, }, @@ -142,7 +139,7 @@ func Test_Decode(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - branch, err := Decode(testCase.reader, testCase.header) + branch, err := DecodeBranch(testCase.reader, testCase.header) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { @@ -152,3 +149,79 @@ func Test_Decode(t *testing.T) { }) } } + +func Test_DecodeLeaf(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + header byte + leaf *Leaf + errWrapped error + errMessage string + }{ + "no data with header 1": { + reader: bytes.NewBuffer(nil), + header: 1, + errWrapped: ErrNodeTypeIsNotALeaf, + errMessage: "node type is not a leaf: 0", + }, + "key decoding error": { + reader: bytes.NewBuffer([]byte{ + // missing key data byte + }), + header: 65, // node type 1 and key length 1 + errWrapped: ErrReadKeyData, + errMessage: "cannot decode key: cannot read key data: EOF", + }, + "value decoding error": { + reader: bytes.NewBuffer([]byte{ + 9, // key data + // missing value data + }), + header: 65, // node type 1 and key length 1 + errWrapped: ErrDecodeValue, + errMessage: "cannot decode value: EOF", + }, + "zero value": { + 
reader: bytes.NewBuffer([]byte{ + 9, // key data + 0, // missing value data + }), + header: 65, // node type 1 and key length 1 + leaf: &Leaf{ + Key: []byte{9}, + Dirty: true, + }, + }, + "success": { + reader: bytes.NewBuffer( + concatByteSlices([][]byte{ + {9}, // key data + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // value data + }), + ), + header: 65, // node type 1 and key length 1 + leaf: &Leaf{ + Key: []byte{9}, + Value: []byte{1, 2, 3, 4, 5}, + Dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + leaf, err := DecodeLeaf(testCase.reader, testCase.header) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.leaf, leaf) + }) + } +} diff --git a/lib/trie/branch/dirty.go b/lib/trie/node/dirty.go similarity index 57% rename from lib/trie/branch/dirty.go rename to lib/trie/node/dirty.go index 930c01fa91..7922139b18 100644 --- a/lib/trie/branch/dirty.go +++ b/lib/trie/node/dirty.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node // IsDirty returns the dirty status of the branch. func (b *Branch) IsDirty() bool { @@ -12,3 +12,13 @@ func (b *Branch) IsDirty() bool { func (b *Branch) SetDirty(dirty bool) { b.Dirty = dirty } + +// IsDirty returns the dirty status of the leaf. +func (l *Leaf) IsDirty() bool { + return l.Dirty +} + +// SetDirty sets the dirty status to the leaf. +func (l *Leaf) SetDirty(dirty bool) { + l.Dirty = dirty +} diff --git a/lib/trie/encodedecode_test/branch_test.go b/lib/trie/node/encode_decode_test.go similarity index 66% rename from lib/trie/encodedecode_test/branch_test.go rename to lib/trie/node/encode_decode_test.go index 2e48656888..cc380060d6 100644 --- a/lib/trie/encodedecode_test/branch_test.go +++ b/lib/trie/node/encode_decode_test.go @@ -1,15 +1,12 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package encodedecode_test +package node import ( "bytes" "testing" - "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/leaf" - "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,48 +15,48 @@ func Test_Branch_Encode_Decode(t *testing.T) { t.Parallel() testCases := map[string]struct { - branchToEncode *branch.Branch - branchDecoded *branch.Branch + branchToEncode *Branch + branchDecoded *Branch }{ "empty branch": { - branchToEncode: new(branch.Branch), - branchDecoded: &branch.Branch{ + branchToEncode: new(Branch), + branchDecoded: &Branch{ Key: []byte{}, Dirty: true, }, }, "branch with key 5": { - branchToEncode: &branch.Branch{ + branchToEncode: &Branch{ Key: []byte{5}, }, - branchDecoded: &branch.Branch{ + branchDecoded: &Branch{ Key: []byte{5}, Dirty: true, }, }, "branch with two bytes key": { - branchToEncode: &branch.Branch{ + branchToEncode: &Branch{ Key: []byte{0xf, 0xa}, // note: each byte cannot be larger than 0xf }, - branchDecoded: &branch.Branch{ + branchDecoded: &Branch{ Key: []byte{0xf, 0xa}, Dirty: true, }, }, "branch with child": { - branchToEncode: &branch.Branch{ + branchToEncode: &Branch{ Key: []byte{5}, - Children: [16]node.Node{ - &leaf.Leaf{ + Children: [16]Node{ + &Leaf{ Key: []byte{9}, Value: []byte{10}, }, }, }, - branchDecoded: &branch.Branch{ + branchDecoded: &Branch{ Key: []byte{5}, - Children: [16]node.Node{ - &leaf.Leaf{ + 
Children: [16]Node{ + &Leaf{ Hash: []byte{0x41, 0x9, 0x4, 0xa}, }, }, @@ -83,7 +80,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { require.NoError(t, err) header := oneBuffer[0] - resultBranch, err := branch.Decode(buffer, header) + resultBranch, err := DecodeBranch(buffer, header) require.NoError(t, err) assert.Equal(t, testCase.branchDecoded, resultBranch) diff --git a/lib/trie/encode/doc.go b/lib/trie/node/encode_doc.go similarity index 92% rename from lib/trie/encode/doc.go rename to lib/trie/node/encode_doc.go index e2fc9fd64d..6cb2699d59 100644 --- a/lib/trie/encode/doc.go +++ b/lib/trie/node/encode_doc.go @@ -1,7 +1,4 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package encode +package node //nolint:lll // Modified Merkle-Patricia Trie diff --git a/lib/trie/node/encode_test.go b/lib/trie/node/encode_test.go new file mode 100644 index 0000000000..ae8593ba23 --- /dev/null +++ b/lib/trie/node/encode_test.go @@ -0,0 +1,11 @@ +package node + +import "errors" + +type writeCall struct { + written []byte + n int + err error +} + +var errTest = errors.New("test error") diff --git a/lib/trie/branch/generation.go b/lib/trie/node/generation.go similarity index 56% rename from lib/trie/branch/generation.go rename to lib/trie/node/generation.go index a5d8f4e510..ac7adc6aca 100644 --- a/lib/trie/branch/generation.go +++ b/lib/trie/node/generation.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node // SetGeneration sets the generation given to the branch. func (b *Branch) SetGeneration(generation uint64) { @@ -12,3 +12,13 @@ func (b *Branch) SetGeneration(generation uint64) { func (b *Branch) GetGeneration() uint64 { return b.Generation } + +// SetGeneration sets the generation given to the leaf. +func (l *Leaf) SetGeneration(generation uint64) { + l.Generation = generation +} + +// GetGeneration returns the generation of the leaf. +func (l *Leaf) GetGeneration() uint64 { + return l.Generation +} diff --git a/lib/trie/branch/hash.go b/lib/trie/node/hash.go similarity index 99% rename from lib/trie/branch/hash.go rename to lib/trie/node/hash.go index d826dc7ba8..14086a732c 100644 --- a/lib/trie/branch/hash.go +++ b/lib/trie/node/hash.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "bytes" diff --git a/lib/trie/branch/header.go b/lib/trie/node/header.go similarity index 54% rename from lib/trie/branch/header.go rename to lib/trie/node/header.go index bbc7a683b3..b5d28a6415 100644 --- a/lib/trie/branch/header.go +++ b/lib/trie/node/header.go @@ -1,12 +1,10 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "io" - - "github.com/ChainSafe/gossamer/lib/trie/encode" ) // encodeHeader creates the encoded header for the branch. @@ -25,7 +23,7 @@ func (b *Branch) encodeHeader(writer io.Writer) (err error) { return err } - err = encode.KeyLength(len(b.Key), writer) + err = encodeKeyLength(len(b.Key), writer) if err != nil { return err } @@ -39,3 +37,27 @@ func (b *Branch) encodeHeader(writer io.Writer) (err error) { return nil } + +// encodeHeader creates the encoded header for the leaf. 
+func (l *Leaf) encodeHeader(writer io.Writer) (err error) { + var header byte = 1 << 6 + + if len(l.Key) < 63 { + header = header | byte(len(l.Key)) + _, err = writer.Write([]byte{header}) + return err + } + + header = header | 0x3f + _, err = writer.Write([]byte{header}) + if err != nil { + return err + } + + err = encodeKeyLength(len(l.Key), writer) + if err != nil { + return err + } + + return nil +} diff --git a/lib/trie/branch/header_test.go b/lib/trie/node/header_test.go similarity index 51% rename from lib/trie/branch/header_test.go rename to lib/trie/node/header_test.go index 2209518549..78f40344ca 100644 --- a/lib/trie/branch/header_test.go +++ b/lib/trie/node/header_test.go @@ -1,13 +1,12 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package branch +package node import ( "testing" - "github.com/ChainSafe/gossamer/lib/trie/encode" - gomock "github.com/golang/mock/gomock" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -73,7 +72,7 @@ func Test_Branch_encodeHeader(t *testing.T) { writes: []writeCall{ {written: []byte{0xbf}}, }, - errWrapped: encode.ErrPartialKeyTooBig, + errWrapped: ErrPartialKeyTooBig, errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", }, "small key length write error": { @@ -130,3 +129,117 @@ func Test_Branch_encodeHeader(t *testing.T) { }) } } + +func Test_Leaf_encodeHeader(t *testing.T) { + testCases := map[string]struct { + leaf *Leaf + writes []writeCall + errWrapped error + errMessage string + }{ + "no key": { + leaf: &Leaf{}, + writes: []writeCall{ + {written: []byte{0x40}}, + }, + }, + "key of length 30": { + leaf: &Leaf{ + Key: make([]byte, 30), + }, + writes: []writeCall{ + {written: []byte{0x5e}}, + }, + }, + "short key write error": { + leaf: &Leaf{ + Key: make([]byte, 30), + }, + writes: []writeCall{ + { + written: []byte{0x5e}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + "key of length 62": { + leaf: &Leaf{ + Key: make([]byte, 62), + }, + writes: []writeCall{ + {written: []byte{0x7e}}, + }, + }, + "key of length 63": { + leaf: &Leaf{ + Key: make([]byte, 63), + }, + writes: []writeCall{ + {written: []byte{0x7f}}, + {written: []byte{0x0}}, + }, + }, + "key of length 64": { + leaf: &Leaf{ + Key: make([]byte, 64), + }, + writes: []writeCall{ + {written: []byte{0x7f}}, + {written: []byte{0x1}}, + }, + }, + "long key first byte write error": { + leaf: &Leaf{ + Key: make([]byte, 63), + }, + writes: []writeCall{ + { + written: []byte{0x7f}, + err: errTest, + }, + }, + errWrapped: errTest, + errMessage: errTest.Error(), + }, + "key too big": { + leaf: &Leaf{ + Key: make([]byte, 65535+63), + }, + writes: []writeCall{ + {written: []byte{0x7f}}, + }, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65535", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + writer := NewMockWriter(ctrl) + var previousCall *gomock.Call + for _, write := range testCase.writes { + call := writer.EXPECT(). + Write(write.written). 
+ Return(write.n, write.err) + + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } + + err := testCase.leaf.encodeHeader(writer) + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/lib/trie/decode/key.go b/lib/trie/node/key.go similarity index 55% rename from lib/trie/decode/key.go rename to lib/trie/node/key.go index 4ab24adaff..2248efd963 100644 --- a/lib/trie/decode/key.go +++ b/lib/trie/node/key.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package decode +package node import ( "bytes" @@ -9,9 +9,24 @@ import ( "fmt" "io" + "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/lib/trie/pools" ) +// SetKey sets the key to the branch. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the branch. +func (b *Branch) SetKey(key []byte) { + b.Key = key +} + +// SetKey sets the key to the leaf. +// Note it does not copy it so modifying the passed key +// will modify the key stored in the leaf. +func (l *Leaf) SetKey(key []byte) { + l.Key = key +} + const maxPartialKeySize = ^uint16(0) var ( @@ -20,8 +35,36 @@ var ( ErrReadKeyData = errors.New("cannot read key data") ) -// Key decodes a key from a reader. -func Key(reader io.Reader, keyLength byte) (b []byte, err error) { +// encodeKeyLength encodes the key length. +func encodeKeyLength(keyLength int, writer io.Writer) (err error) { + keyLength -= 63 + + if keyLength >= int(maxPartialKeySize) { + return fmt.Errorf("%w: %d", + ErrPartialKeyTooBig, keyLength) + } + + for i := uint16(0); i < maxPartialKeySize; i++ { + if keyLength < 255 { + _, err = writer.Write([]byte{byte(keyLength)}) + if err != nil { + return err + } + break + } + _, err = writer.Write([]byte{255}) + if err != nil { + return err + } + + keyLength -= 255 + } + + return nil +} + +// decodeKey decodes a key from a reader. +func decodeKey(reader io.Reader, keyLength byte) (b []byte, err error) { publicKeyLength := int(keyLength) if keyLength == 0x3f { @@ -62,24 +105,5 @@ func Key(reader io.Reader, keyLength byte) (b []byte, err error) { ErrReadKeyData, n, len(key)) } - return KeyLEToNibbles(key)[publicKeyLength%2:], nil -} - -// KeyLEToNibbles converts a Little Endian byte slice into nibbles. -// It assumes bytes are already in Little Endian and does not rearrange nibbles. 
-func KeyLEToNibbles(in []byte) (nibbles []byte) { - if len(in) == 0 { - return []byte{} - } else if len(in) == 1 && in[0] == 0 { - return []byte{0, 0} - } - - l := len(in) * 2 - nibbles = make([]byte, l) - for i, b := range in { - nibbles[2*i] = b / 16 - nibbles[2*i+1] = b % 16 - } - - return nibbles + return codec.KeyLEToNibbles(key)[publicKeyLength%2:], nil } diff --git a/lib/trie/encode/key_test.go b/lib/trie/node/key_test.go similarity index 59% rename from lib/trie/encode/key_test.go rename to lib/trie/node/key_test.go index 058be5643e..088ab38eb4 100644 --- a/lib/trie/encode/key_test.go +++ b/lib/trie/node/key_test.go @@ -1,11 +1,8 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package encode +package node import ( "bytes" - "errors" + "io" "testing" "github.com/golang/mock/gomock" @@ -13,17 +10,15 @@ import ( "github.com/stretchr/testify/require" ) -type writeCall struct { - written []byte - n int - err error +func repeatBytes(n int, b byte) (slice []byte) { + slice = make([]byte, n) + for i := range slice { + slice[i] = b + } + return slice } -var errTest = errors.New("test error") - -//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer - -func Test_KeyLength(t *testing.T) { +func Test_encodeKeyLength(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -112,7 +107,7 @@ func Test_KeyLength(t *testing.T) { previousCall = call } - err := KeyLength(testCase.keyLength, writer) + err := encodeKeyLength(testCase.keyLength, writer) assert.ErrorIs(t, err, testCase.errWrapped) if testCase.errWrapped != nil { @@ -139,46 +134,66 @@ func Test_KeyLength(t *testing.T) { buffer := bytes.NewBuffer(nil) buffer.Grow(expectedEncodingLength) - err := KeyLength(keyLength, buffer) + err := encodeKeyLength(keyLength, buffer) require.NoError(t, err) assert.Equal(t, expectedBytes, buffer.Bytes()) }) } -func Test_NibblesToKeyLE(t *testing.T) { +func Test_decodeKey(t *testing.T) { t.Parallel() testCases := map[string]struct { - nibbles []byte - keyLE []byte + reader io.Reader + keyLength byte + b []byte + errWrapped error + errMessage string }{ - "nil nibbles": { - keyLE: []byte{}, + "zero key length": { + b: []byte{}, }, - "empty nibbles": { - nibbles: []byte{}, - keyLE: []byte{}, + "short key length": { + reader: bytes.NewBuffer([]byte{1, 2, 3}), + keyLength: 5, + b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, }, - "0xF 0xF": { - nibbles: []byte{0xF, 0xF}, - keyLE: []byte{0xFF}, + "key read error": { + reader: bytes.NewBuffer(nil), + keyLength: 5, + errWrapped: ErrReadKeyData, + errMessage: "cannot read key data: EOF", }, - "0x3 0xa 0x0 0x5": { - nibbles: []byte{0x3, 0xa, 0x0, 0x5}, - keyLE: []byte{0x3a, 0x05}, + "long key length": { + reader: bytes.NewBuffer( + append( + []byte{ + 6, // key length + }, + repeatBytes(64, 7)..., // key data + )), + keyLength: 0x3f, + b: []byte{ + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, + 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, }, - "0xa 0xa 0xf 0xf 0x0 0x1": { - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1}, - keyLE: []byte{0xaa, 0xff, 0x01}, + "key length read error": { + reader: bytes.NewBuffer(nil), + keyLength: 0x3f, + errWrapped: ErrReadKeyLength, + errMessage: "cannot read key length: EOF", }, - "0xa 0xa 0xf 0xf 0x0 0x1 
0xc 0x2": { - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc, 0x2}, - keyLE: []byte{0xaa, 0xff, 0x01, 0xc2}, - }, - "0xa 0xa 0xf 0xf 0x0 0x1 0xc": { - nibbles: []byte{0xa, 0xa, 0xf, 0xf, 0x0, 0x1, 0xc}, - keyLE: []byte{0xa, 0xaf, 0xf0, 0x1c}, + "key length too big": { + reader: bytes.NewBuffer(repeatBytes(257, 0xff)), + keyLength: 0x3f, + errWrapped: ErrPartialKeyTooBig, + errMessage: "partial key length cannot be larger than or equal to 2^16: 65598", }, } @@ -187,9 +202,13 @@ func Test_NibblesToKeyLE(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - keyLE := NibblesToKeyLE(testCase.nibbles) + b, err := decodeKey(testCase.reader, testCase.keyLength) - assert.Equal(t, testCase.keyLE, keyLE) + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.b, b) }) } } diff --git a/lib/trie/leaf/leaf.go b/lib/trie/node/leaf.go similarity index 87% rename from lib/trie/leaf/leaf.go rename to lib/trie/node/leaf.go index 7e86ad9c95..d77fe3339f 100644 --- a/lib/trie/leaf/leaf.go +++ b/lib/trie/node/leaf.go @@ -1,17 +1,16 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package leaf +package node import ( "fmt" "sync" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/node" ) -var _ node.Node = (*Leaf)(nil) +var _ Node = (*Leaf)(nil) // Leaf is a leaf in the trie. type Leaf struct { diff --git a/lib/trie/leaf/encode.go b/lib/trie/node/leaf_encode.go similarity index 96% rename from lib/trie/leaf/encode.go rename to lib/trie/node/leaf_encode.go index 477d6dd78d..8caafb10e5 100644 --- a/lib/trie/leaf/encode.go +++ b/lib/trie/node/leaf_encode.go @@ -1,7 +1,7 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package leaf +package node import ( "bytes" @@ -10,7 +10,7 @@ import ( "io" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/encode" + "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/lib/trie/pools" "github.com/ChainSafe/gossamer/pkg/scale" ) @@ -85,7 +85,7 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { // Encode encodes a leaf to the buffer given. 
// The encoding has the following format: // NodeHeader | Extra partial key length | Partial Key | Value -func (l *Leaf) Encode(buffer encode.Buffer) (err error) { +func (l *Leaf) Encode(buffer Buffer) (err error) { l.encodingMu.RLock() if !l.Dirty && l.Encoding != nil { _, err = buffer.Write(l.Encoding) @@ -102,7 +102,7 @@ func (l *Leaf) Encode(buffer encode.Buffer) (err error) { return fmt.Errorf("cannot encode header: %w", err) } - keyLE := encode.NibblesToKeyLE(l.Key) + keyLE := codec.NibblesToKeyLE(l.Key) _, err = buffer.Write(keyLE) if err != nil { return fmt.Errorf("cannot write LE key to buffer: %w", err) diff --git a/lib/trie/leaf/encode_test.go b/lib/trie/node/leaf_encode_test.go similarity index 96% rename from lib/trie/leaf/encode_test.go rename to lib/trie/node/leaf_encode_test.go index 51646719b2..2edcb0e289 100644 --- a/lib/trie/leaf/encode_test.go +++ b/lib/trie/node/leaf_encode_test.go @@ -1,26 +1,16 @@ // Copyright 2021 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package leaf +package node import ( - "errors" "testing" - "github.com/ChainSafe/gossamer/lib/trie/encode" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -type writeCall struct { - written []byte - n int - err error -} - -var errTest = errors.New("test error") - //go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/encode Buffer func Test_Leaf_Encode(t *testing.T) { @@ -70,7 +60,7 @@ func Test_Leaf_Encode(t *testing.T) { written: []byte{127}, }, }, - wrappedErr: encode.ErrPartialKeyTooBig, + wrappedErr: ErrPartialKeyTooBig, errMessage: "cannot encode header: partial key length cannot be larger than or equal to 2^16: 65536", }, "buffer write error for encoded key": { diff --git a/lib/trie/node/interface.go b/lib/trie/node/node.go similarity index 77% rename from lib/trie/node/interface.go rename to lib/trie/node/node.go index 1e52a39204..63fe57443f 100644 --- a/lib/trie/node/interface.go +++ b/lib/trie/node/node.go @@ -3,13 +3,9 @@ package node -import ( - "github.com/ChainSafe/gossamer/lib/trie/encode" -) - // Node is node in the trie and can be a leaf or a branch. type Node interface { - Encode(buffer encode.Buffer) (err error) // TODO change to io.Writer + Encode(buffer Buffer) (err error) // TODO change to io.Writer EncodeAndHash() ([]byte, []byte, error) ScaleEncodeHash() (b []byte, err error) IsDirty() bool diff --git a/lib/trie/leaf/writer_mock_test.go b/lib/trie/node/writer_mock_test.go similarity index 95% rename from lib/trie/leaf/writer_mock_test.go rename to lib/trie/node/writer_mock_test.go index 04cb474a72..9665f01c85 100644 --- a/lib/trie/leaf/writer_mock_test.go +++ b/lib/trie/node/writer_mock_test.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: io (interfaces: Writer) -// Package leaf is a generated GoMock package. -package leaf +// Package node is a generated GoMock package. 
+package node import ( reflect "reflect" diff --git a/lib/trie/print.go b/lib/trie/print.go index ce3f85a979..4cbec8c39c 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -8,8 +8,6 @@ import ( "fmt" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/leaf" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/lib/trie/pools" @@ -29,7 +27,7 @@ func (t *Trie) String() string { func (t *Trie) string(tree gotree.Tree, curr node.Node, idx int) { switch c := curr.(type) { - case *branch.Branch: + case *node.Branch: buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() @@ -51,7 +49,7 @@ func (t *Trie) string(tree gotree.Tree, curr node.Node, idx int) { t.string(sub, child, i) } } - case *leaf.Leaf: + case *node.Leaf: buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) buffer.Reset() diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 094fe48f7e..e79c47e28e 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -8,7 +8,7 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/decode" + "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/lib/trie/record" ) @@ -39,7 +39,7 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e } for _, k := range keys { - nk := decode.KeyLEToNibbles(k) + nk := codec.KeyLEToNibbles(k) recorder := record.NewRecorder() err := findAndRecord(proofTrie, nk, recorder) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 1995399974..4595d7cbff 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -8,10 +8,7 @@ import ( "fmt" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/branch" - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/lib/trie/encode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/lib/trie/pools" ) @@ -151,15 +148,15 @@ func (t *Trie) Entries() map[string][]byte { func (t *Trie) entries(current node.Node, prefix []byte, kv map[string][]byte) map[string][]byte { switch c := current.(type) { - case *branch.Branch: + case *node.Branch: if c.Value != nil { - kv[string(encode.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value + kv[string(codec.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value } for i, child := range c.Children { t.entries(child, append(prefix, append(c.Key, byte(i))...), kv) } - case *leaf.Leaf: - kv[string(encode.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value + case *node.Leaf: + kv[string(codec.NibblesToKeyLE(append(prefix, c.Key...)))] = c.Value return kv } @@ -168,19 +165,19 @@ func (t *Trie) entries(current node.Node, prefix []byte, kv map[string][]byte) m // NextKey returns the next key in the trie in lexicographic order. It returns nil if there is no next key func (t *Trie) NextKey(key []byte) []byte { - k := decode.KeyLEToNibbles(key) + k := codec.KeyLEToNibbles(key) next := t.nextKey(t.root, nil, k) if next == nil { return nil } - return encode.NibblesToKeyLE(next) + return codec.NibblesToKeyLE(next) } func (t *Trie) nextKey(curr node.Node, prefix, key []byte) []byte { switch c := curr.(type) { - case *branch.Branch: + case *node.Branch: fullKey := append(prefix, c.Key...) 
var cmp int if len(key) < len(fullKey) { @@ -229,7 +226,7 @@ func (t *Trie) nextKey(curr node.Node, prefix, key []byte) []byte { } } } - case *leaf.Leaf: + case *node.Leaf: fullKey := append(prefix, c.Key...) var cmp int if len(key) < len(fullKey) { @@ -259,15 +256,15 @@ func (t *Trie) Put(key, value []byte) { } func (t *Trie) tryPut(key, value []byte) { - k := decode.KeyLEToNibbles(key) + k := codec.KeyLEToNibbles(key) - t.root = t.insert(t.root, k, &leaf.Leaf{Key: nil, Value: value, Dirty: true, Generation: t.generation}) + t.root = t.insert(t.root, k, &node.Leaf{Key: nil, Value: value, Dirty: true, Generation: t.generation}) } // insert attempts to insert a key with value into the trie func (t *Trie) insert(parent node.Node, key []byte, value node.Node) node.Node { switch p := t.maybeUpdateGeneration(parent).(type) { - case *branch.Branch: + case *node.Branch: n := t.updateBranch(p, key, value) if p != nil && n != nil && n.IsDirty() { @@ -277,12 +274,12 @@ func (t *Trie) insert(parent node.Node, key []byte, value node.Node) node.Node { case nil: value.SetKey(key) return value - case *leaf.Leaf: + case *node.Leaf: // if a value already exists in the trie at this key, overwrite it with the new value // if the values are the same, don't mark node dirty if p.Value != nil && bytes.Equal(p.Key, key) { - if !bytes.Equal(value.(*leaf.Leaf).Value, p.Value) { - p.Value = value.(*leaf.Leaf).Value + if !bytes.Equal(value.(*node.Leaf).Value, p.Value) { + p.Value = value.(*node.Leaf).Value p.Dirty = true } return p @@ -291,12 +288,12 @@ func (t *Trie) insert(parent node.Node, key []byte, value node.Node) node.Node { length := lenCommonPrefix(key, p.Key) // need to convert this leaf into a branch - br := &branch.Branch{Key: key[:length], Dirty: true, Generation: t.generation} + br := &node.Branch{Key: key[:length], Dirty: true, Generation: t.generation} parentKey := p.Key // value goes at this branch if len(key) == length { - br.Value = value.(*leaf.Leaf).Value + br.Value = value.(*node.Leaf).Value br.SetDirty(true) // if we are not replacing previous leaf, then add it as a child to the new branch @@ -333,7 +330,7 @@ func (t *Trie) insert(parent node.Node, key []byte, value node.Node) node.Node { // updateBranch attempts to add the value node to a branch // inserts the value node as the branch's child at the index that's // the first nibble of the key -func (t *Trie) updateBranch(p *branch.Branch, key []byte, value node.Node) (n node.Node) { +func (t *Trie) updateBranch(p *node.Branch, key []byte, value node.Node) (n node.Node) { length := lenCommonPrefix(key, p.Key) // whole parent key matches @@ -342,16 +339,16 @@ func (t *Trie) updateBranch(p *branch.Branch, key []byte, value node.Node) (n no if bytes.Equal(key, p.Key) { p.SetDirty(true) switch v := value.(type) { - case *branch.Branch: + case *node.Branch: p.Value = v.Value - case *leaf.Leaf: + case *node.Leaf: p.Value = v.Value } return p } switch c := p.Children[key[length]].(type) { - case *branch.Branch, *leaf.Leaf: + case *node.Branch, *node.Leaf: n = t.insert(c, key[length+1:], value) p.Children[key[length]] = n n.SetDirty(true) @@ -359,7 +356,7 @@ func (t *Trie) updateBranch(p *branch.Branch, key []byte, value node.Node) (n no return p case nil: // otherwise, add node as child of this branch - value.(*leaf.Leaf).Key = key[length+1:] + value.(*node.Leaf).Key = key[length+1:] p.Children[key[length]] = value p.SetDirty(true) return p @@ -370,13 +367,13 @@ func (t *Trie) updateBranch(p *branch.Branch, key []byte, value node.Node) (n no // we 
need to branch out at the point where the keys diverge // update partial keys, new branch has key up to matching length - br := &branch.Branch{Key: key[:length], Dirty: true, Generation: t.generation} + br := &node.Branch{Key: key[:length], Dirty: true, Generation: t.generation} parentIndex := p.Key[length] br.Children[parentIndex] = t.insert(nil, p.Key[length+1:], p) if len(key) <= length { - br.Value = value.(*leaf.Leaf).Value + br.Value = value.(*node.Leaf).Value } else { br.Children[key[length]] = t.insert(nil, key[length+1:], value) } @@ -406,7 +403,7 @@ func (t *Trie) LoadFromMap(data map[string]string) error { func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { var p []byte if len(prefix) != 0 { - p = decode.KeyLEToNibbles(prefix) + p = codec.KeyLEToNibbles(prefix) if p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -417,7 +414,7 @@ func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { func (t *Trie) getKeysWithPrefix(parent node.Node, prefix, key []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *branch.Branch: + case *node.Branch: length := lenCommonPrefix(p.Key, key) if bytes.Equal(p.Key[:length], key) || len(key) == 0 { @@ -433,10 +430,10 @@ func (t *Trie) getKeysWithPrefix(parent node.Node, prefix, key []byte, keys [][] key = key[len(p.Key):] keys = t.getKeysWithPrefix(p.Children[key[0]], append(append(prefix, p.Key...), key[0]), key[1:], keys) - case *leaf.Leaf: + case *node.Leaf: length := lenCommonPrefix(p.Key, key) if bytes.Equal(p.Key[:length], key) || len(key) == 0 { - keys = append(keys, encode.NibblesToKeyLE(append(prefix, p.Key...))) + keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) } case nil: return keys @@ -448,16 +445,16 @@ func (t *Trie) getKeysWithPrefix(parent node.Node, prefix, key []byte, keys [][] // it uses the prefix to determine the entire key func (t *Trie) addAllKeys(parent node.Node, prefix []byte, keys [][]byte) [][]byte { switch p := parent.(type) { - case *branch.Branch: + case *node.Branch: if p.Value != nil { - keys = append(keys, encode.NibblesToKeyLE(append(prefix, p.Key...))) + keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) } for i, child := range p.Children { keys = t.addAllKeys(child, append(append(prefix, p.Key...), byte(i)), keys) } - case *leaf.Leaf: - keys = append(keys, encode.NibblesToKeyLE(append(prefix, p.Key...))) + case *node.Leaf: + keys = append(keys, codec.NibblesToKeyLE(append(prefix, p.Key...))) case nil: return keys } @@ -475,23 +472,23 @@ func (t *Trie) Get(key []byte) []byte { return l.Value } -func (t *Trie) tryGet(key []byte) *leaf.Leaf { - k := decode.KeyLEToNibbles(key) +func (t *Trie) tryGet(key []byte) *node.Leaf { + k := codec.KeyLEToNibbles(key) return t.retrieve(t.root, k) } -func (t *Trie) retrieve(parent node.Node, key []byte) *leaf.Leaf { +func (t *Trie) retrieve(parent node.Node, key []byte) *node.Leaf { var ( - value *leaf.Leaf + value *node.Leaf ) switch p := parent.(type) { - case *branch.Branch: + case *node.Branch: length := lenCommonPrefix(p.Key, key) // found the value at this node if bytes.Equal(p.Key, key) || len(key) == 0 { - return &leaf.Leaf{Key: p.Key, Value: p.Value, Dirty: false} + return &node.Leaf{Key: p.Key, Value: p.Value, Dirty: false} } // did not find value @@ -500,7 +497,7 @@ func (t *Trie) retrieve(parent node.Node, key []byte) *leaf.Leaf { } value = t.retrieve(p.Children[key[length]], key[length+1:]) - case *leaf.Leaf: + case *node.Leaf: if bytes.Equal(p.Key, key) { value = p } @@ -516,7 +513,7 @@ func (t *Trie) 
ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { return 0, false } - p := decode.KeyLEToNibbles(prefix) + p := codec.KeyLEToNibbles(prefix) if len(p) > 0 && p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -533,7 +530,7 @@ func (t *Trie) clearPrefixLimit(cn node.Node, prefix []byte, limit *uint32) (nod curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *branch.Branch: + case *node.Branch: length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { n, _ := t.deleteNodes(c, []byte{}, limit) @@ -571,7 +568,7 @@ func (t *Trie) clearPrefixLimit(cn node.Node, prefix []byte, limit *uint32) (nod } return curr, curr.IsDirty(), allDeleted - case *leaf.Leaf: + case *node.Leaf: length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { *limit-- @@ -591,13 +588,13 @@ func (t *Trie) deleteNodes(cn node.Node, prefix []byte, limit *uint32) (node.Nod curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *leaf.Leaf: + case *node.Leaf: if *limit == 0 { return c, false } *limit-- return nil, true - case *branch.Branch: + case *node.Branch: if len(c.Key) != 0 { prefix = append(prefix, c.Key...) } @@ -645,7 +642,7 @@ func (t *Trie) ClearPrefix(prefix []byte) { return } - p := decode.KeyLEToNibbles(prefix) + p := codec.KeyLEToNibbles(prefix) if len(p) > 0 && p[len(p)-1] == 0 { p = p[:len(p)-1] } @@ -656,7 +653,7 @@ func (t *Trie) ClearPrefix(prefix []byte) { func (t *Trie) clearPrefix(cn node.Node, prefix []byte) (node.Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { - case *branch.Branch: + case *node.Branch: length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { @@ -690,7 +687,7 @@ func (t *Trie) clearPrefix(cn node.Node, prefix []byte) (node.Node, bool) { } return curr, curr.IsDirty() - case *leaf.Leaf: + case *node.Leaf: length := lenCommonPrefix(c.Key, prefix) if length == len(prefix) { return nil, true @@ -705,14 +702,14 @@ func (t *Trie) clearPrefix(cn node.Node, prefix []byte) (node.Node, bool) { // Delete removes any existing value for key from the trie. func (t *Trie) Delete(key []byte) { - k := decode.KeyLEToNibbles(key) + k := codec.KeyLEToNibbles(key) t.root, _ = t.delete(t.root, k) } func (t *Trie) delete(parent node.Node, key []byte) (node.Node, bool) { // Store the current node and return it, if the trie is not updated. switch p := t.maybeUpdateGeneration(parent).(type) { - case *branch.Branch: + case *node.Branch: length := lenCommonPrefix(p.Key, key) if bytes.Equal(p.Key, key) || len(key) == 0 { @@ -732,7 +729,7 @@ func (t *Trie) delete(parent node.Node, key []byte) (node.Node, bool) { p.SetDirty(true) n = handleDeletion(p, key) return n, true - case *leaf.Leaf: + case *node.Leaf: if bytes.Equal(key, p.Key) || len(key) == 0 { // Key exists. Delete it. 
return nil, true @@ -749,14 +746,14 @@ func (t *Trie) delete(parent node.Node, key []byte) (node.Node, bool) { // handleDeletion is called when a value is deleted from a branch // if the updated branch only has 1 child, it should be combined with that child // if the updated branch only has a value, it should be turned into a leaf -func handleDeletion(p *branch.Branch, key []byte) node.Node { +func handleDeletion(p *node.Branch, key []byte) node.Node { var n node.Node = p length := lenCommonPrefix(p.Key, key) bitmap := p.ChildrenBitmap() // if branch has no children, just a value, turn it into a leaf if bitmap == 0 && p.Value != nil { - n = &leaf.Leaf{Key: key[:length], Value: p.Value, Dirty: true} + n = &node.Leaf{Key: key[:length], Value: p.Value, Dirty: true} } else if p.NumChildren() == 1 && p.Value == nil { // there is only 1 child and no value, combine the child branch with this branch // find index of child @@ -770,10 +767,10 @@ func handleDeletion(p *branch.Branch, key []byte) node.Node { child := p.Children[i] switch c := child.(type) { - case *leaf.Leaf: - n = &leaf.Leaf{Key: append(append(p.Key, []byte{byte(i)}...), c.Key...), Value: c.Value} - case *branch.Branch: - br := new(branch.Branch) + case *node.Leaf: + n = &node.Leaf{Key: append(append(p.Key, []byte{byte(i)}...), c.Key...), Value: c.Value} + case *node.Branch: + br := new(node.Branch) br.Key = append(p.Key, append([]byte{byte(i)}, c.Key...)...) // adopt the grandchildren diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index e8c2f885bc..4c8c0acab1 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -21,8 +21,8 @@ import ( "github.com/stretchr/testify/require" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/decode" - "github.com/ChainSafe/gossamer/lib/trie/leaf" + "github.com/ChainSafe/gossamer/lib/trie/codec" + "github.com/ChainSafe/gossamer/lib/trie/node" ) type commonPrefixTest struct { @@ -70,7 +70,7 @@ func TestNewEmptyTrie(t *testing.T) { } func TestNewTrie(t *testing.T) { - trie := NewTrie(&leaf.Leaf{Key: []byte{0}, Value: []byte{17}}) + trie := NewTrie(&node.Leaf{Key: []byte{0}, Value: []byte{17}}) if trie == nil { t.Error("did not initialise trie") } @@ -875,7 +875,7 @@ func TestClearPrefix(t *testing.T) { require.Equal(t, dcTrieHash, ssTrieHash) ssTrie.ClearPrefix(prefix) - prefixNibbles := decode.KeyLEToNibbles(prefix) + prefixNibbles := codec.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -883,7 +883,7 @@ func TestClearPrefix(t *testing.T) { for _, test := range tests { res := ssTrie.Get(test.key) - keyNibbles := decode.KeyLEToNibbles(test.key) + keyNibbles := codec.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { require.Nil(t, res) @@ -944,8 +944,8 @@ func TestClearPrefix_Small(t *testing.T) { } ssTrie.ClearPrefix([]byte("noo")) - require.Equal(t, ssTrie.root, &leaf.Leaf{ - Key: decode.KeyLEToNibbles([]byte("other")), + require.Equal(t, ssTrie.root, &node.Leaf{ + Key: codec.KeyLEToNibbles([]byte("other")), Value: []byte("other"), Dirty: true, }) @@ -1293,7 +1293,7 @@ func TestTrie_ClearPrefixLimit(t *testing.T) { } testFn := func(testCase []Test, prefix []byte) { - prefixNibbles := decode.KeyLEToNibbles(prefix) + prefixNibbles := codec.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] 
} @@ -1312,7 +1312,7 @@ func TestTrie_ClearPrefixLimit(t *testing.T) { for _, test := range testCase { val := trieClearPrefix.Get(test.key) - keyNibbles := decode.KeyLEToNibbles(test.key) + keyNibbles := codec.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { @@ -1401,7 +1401,7 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { for _, testCase := range cases { for _, prefix := range prefixes { - prefixNibbles := decode.KeyLEToNibbles(prefix) + prefixNibbles := codec.KeyLEToNibbles(prefix) if len(prefixNibbles) > 0 && prefixNibbles[len(prefixNibbles)-1] == 0 { prefixNibbles = prefixNibbles[:len(prefixNibbles)-1] } @@ -1441,7 +1441,7 @@ func TestTrie_ClearPrefixLimitSnapshot(t *testing.T) { for _, test := range testCase { val := ssTrie.Get(test.key) - keyNibbles := decode.KeyLEToNibbles(test.key) + keyNibbles := codec.KeyLEToNibbles(test.key) length := lenCommonPrefix(keyNibbles, prefixNibbles) if length == len(prefixNibbles) { From d35fc0fcbd8d2a44d2bb84900449a33a4f619a5c Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 09:35:21 +0000 Subject: [PATCH 20/50] `lib/trie/node` -> `internal/trie/node` --- {lib => internal}/trie/node/branch.go | 0 {lib => internal}/trie/node/branch_encode.go | 0 {lib => internal}/trie/node/branch_encode_test.go | 0 {lib => internal}/trie/node/buffer.go | 0 {lib => internal}/trie/node/buffer_mock_test.go | 0 {lib => internal}/trie/node/children.go | 0 {lib => internal}/trie/node/children_test.go | 0 {lib => internal}/trie/node/copy.go | 0 {lib => internal}/trie/node/decode.go | 0 {lib => internal}/trie/node/decode_test.go | 0 {lib => internal}/trie/node/dirty.go | 0 {lib => internal}/trie/node/encode_decode_test.go | 0 {lib => internal}/trie/node/encode_doc.go | 0 {lib => internal}/trie/node/encode_test.go | 0 {lib => internal}/trie/node/generation.go | 0 {lib => internal}/trie/node/hash.go | 0 {lib => internal}/trie/node/header.go | 0 {lib => internal}/trie/node/header_test.go | 0 {lib => internal}/trie/node/key.go | 0 {lib => internal}/trie/node/key_test.go | 0 {lib => internal}/trie/node/leaf.go | 0 {lib => internal}/trie/node/leaf_encode.go | 0 {lib => internal}/trie/node/leaf_encode_test.go | 0 {lib => internal}/trie/node/node.go | 0 {lib => internal}/trie/node/types.go | 0 {lib => internal}/trie/node/writer_mock_test.go | 0 lib/trie/database.go | 2 +- lib/trie/decode.go | 2 +- lib/trie/decode_test.go | 2 +- lib/trie/lookup.go | 2 +- lib/trie/node.go | 6 ++++++ lib/trie/print.go | 2 +- lib/trie/trie.go | 2 +- lib/trie/trie_test.go | 2 +- 34 files changed, 13 insertions(+), 7 deletions(-) rename {lib => internal}/trie/node/branch.go (100%) rename {lib => internal}/trie/node/branch_encode.go (100%) rename {lib => internal}/trie/node/branch_encode_test.go (100%) rename {lib => internal}/trie/node/buffer.go (100%) rename {lib => internal}/trie/node/buffer_mock_test.go (100%) rename {lib => internal}/trie/node/children.go (100%) rename {lib => internal}/trie/node/children_test.go (100%) rename {lib => internal}/trie/node/copy.go (100%) rename {lib => internal}/trie/node/decode.go (100%) rename {lib => internal}/trie/node/decode_test.go (100%) rename {lib => internal}/trie/node/dirty.go (100%) rename {lib => internal}/trie/node/encode_decode_test.go (100%) rename {lib => internal}/trie/node/encode_doc.go (100%) rename {lib => internal}/trie/node/encode_test.go (100%) rename {lib => internal}/trie/node/generation.go (100%) rename {lib => internal}/trie/node/hash.go (100%) rename 
{lib => internal}/trie/node/header.go (100%) rename {lib => internal}/trie/node/header_test.go (100%) rename {lib => internal}/trie/node/key.go (100%) rename {lib => internal}/trie/node/key_test.go (100%) rename {lib => internal}/trie/node/leaf.go (100%) rename {lib => internal}/trie/node/leaf_encode.go (100%) rename {lib => internal}/trie/node/leaf_encode_test.go (100%) rename {lib => internal}/trie/node/node.go (100%) rename {lib => internal}/trie/node/types.go (100%) rename {lib => internal}/trie/node/writer_mock_test.go (100%) create mode 100644 lib/trie/node.go diff --git a/lib/trie/node/branch.go b/internal/trie/node/branch.go similarity index 100% rename from lib/trie/node/branch.go rename to internal/trie/node/branch.go diff --git a/lib/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go similarity index 100% rename from lib/trie/node/branch_encode.go rename to internal/trie/node/branch_encode.go diff --git a/lib/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go similarity index 100% rename from lib/trie/node/branch_encode_test.go rename to internal/trie/node/branch_encode_test.go diff --git a/lib/trie/node/buffer.go b/internal/trie/node/buffer.go similarity index 100% rename from lib/trie/node/buffer.go rename to internal/trie/node/buffer.go diff --git a/lib/trie/node/buffer_mock_test.go b/internal/trie/node/buffer_mock_test.go similarity index 100% rename from lib/trie/node/buffer_mock_test.go rename to internal/trie/node/buffer_mock_test.go diff --git a/lib/trie/node/children.go b/internal/trie/node/children.go similarity index 100% rename from lib/trie/node/children.go rename to internal/trie/node/children.go diff --git a/lib/trie/node/children_test.go b/internal/trie/node/children_test.go similarity index 100% rename from lib/trie/node/children_test.go rename to internal/trie/node/children_test.go diff --git a/lib/trie/node/copy.go b/internal/trie/node/copy.go similarity index 100% rename from lib/trie/node/copy.go rename to internal/trie/node/copy.go diff --git a/lib/trie/node/decode.go b/internal/trie/node/decode.go similarity index 100% rename from lib/trie/node/decode.go rename to internal/trie/node/decode.go diff --git a/lib/trie/node/decode_test.go b/internal/trie/node/decode_test.go similarity index 100% rename from lib/trie/node/decode_test.go rename to internal/trie/node/decode_test.go diff --git a/lib/trie/node/dirty.go b/internal/trie/node/dirty.go similarity index 100% rename from lib/trie/node/dirty.go rename to internal/trie/node/dirty.go diff --git a/lib/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go similarity index 100% rename from lib/trie/node/encode_decode_test.go rename to internal/trie/node/encode_decode_test.go diff --git a/lib/trie/node/encode_doc.go b/internal/trie/node/encode_doc.go similarity index 100% rename from lib/trie/node/encode_doc.go rename to internal/trie/node/encode_doc.go diff --git a/lib/trie/node/encode_test.go b/internal/trie/node/encode_test.go similarity index 100% rename from lib/trie/node/encode_test.go rename to internal/trie/node/encode_test.go diff --git a/lib/trie/node/generation.go b/internal/trie/node/generation.go similarity index 100% rename from lib/trie/node/generation.go rename to internal/trie/node/generation.go diff --git a/lib/trie/node/hash.go b/internal/trie/node/hash.go similarity index 100% rename from lib/trie/node/hash.go rename to internal/trie/node/hash.go diff --git a/lib/trie/node/header.go b/internal/trie/node/header.go similarity index 100% 
rename from lib/trie/node/header.go rename to internal/trie/node/header.go diff --git a/lib/trie/node/header_test.go b/internal/trie/node/header_test.go similarity index 100% rename from lib/trie/node/header_test.go rename to internal/trie/node/header_test.go diff --git a/lib/trie/node/key.go b/internal/trie/node/key.go similarity index 100% rename from lib/trie/node/key.go rename to internal/trie/node/key.go diff --git a/lib/trie/node/key_test.go b/internal/trie/node/key_test.go similarity index 100% rename from lib/trie/node/key_test.go rename to internal/trie/node/key_test.go diff --git a/lib/trie/node/leaf.go b/internal/trie/node/leaf.go similarity index 100% rename from lib/trie/node/leaf.go rename to internal/trie/node/leaf.go diff --git a/lib/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go similarity index 100% rename from lib/trie/node/leaf_encode.go rename to internal/trie/node/leaf_encode.go diff --git a/lib/trie/node/leaf_encode_test.go b/internal/trie/node/leaf_encode_test.go similarity index 100% rename from lib/trie/node/leaf_encode_test.go rename to internal/trie/node/leaf_encode_test.go diff --git a/lib/trie/node/node.go b/internal/trie/node/node.go similarity index 100% rename from lib/trie/node/node.go rename to internal/trie/node/node.go diff --git a/lib/trie/node/types.go b/internal/trie/node/types.go similarity index 100% rename from lib/trie/node/types.go rename to internal/trie/node/types.go diff --git a/lib/trie/node/writer_mock_test.go b/internal/trie/node/writer_mock_test.go similarity index 100% rename from lib/trie/node/writer_mock_test.go rename to internal/trie/node/writer_mock_test.go diff --git a/lib/trie/database.go b/lib/trie/database.go index 834b707f2c..534114b0a8 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -8,9 +8,9 @@ import ( "errors" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/chaindb" ) diff --git a/lib/trie/decode.go b/lib/trie/decode.go index c3305615c7..75f76f9d6c 100644 --- a/lib/trie/decode.go +++ b/lib/trie/decode.go @@ -9,7 +9,7 @@ import ( "fmt" "io" - "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/trie/pools" ) diff --git a/lib/trie/decode_test.go b/lib/trie/decode_test.go index a4c1b4c068..869fee3d41 100644 --- a/lib/trie/decode_test.go +++ b/lib/trie/decode_test.go @@ -8,7 +8,7 @@ import ( "io" "testing" - "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/pkg/scale" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index c0680a88fc..cfe662af66 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -6,7 +6,7 @@ package trie import ( "bytes" - "github.com/ChainSafe/gossamer/lib/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/trie/record" ) diff --git a/lib/trie/node.go b/lib/trie/node.go new file mode 100644 index 0000000000..758b4dd7ef --- /dev/null +++ b/lib/trie/node.go @@ -0,0 +1,6 @@ +package trie + +import "github.com/ChainSafe/gossamer/internal/trie/node" + +// Node is node in the trie and can be a leaf or a branch. 
+type Node node.Node diff --git a/lib/trie/print.go b/lib/trie/print.go index 4cbec8c39c..8a91160728 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -7,8 +7,8 @@ import ( "bytes" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/lib/trie/pools" "github.com/disiqueira/gotree" diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 4595d7cbff..97fc3a7c8b 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -7,9 +7,9 @@ import ( "bytes" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/node" "github.com/ChainSafe/gossamer/lib/trie/pools" ) diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index 4c8c0acab1..d5260a7d20 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -20,9 +20,9 @@ import ( "github.com/ChainSafe/chaindb" "github.com/stretchr/testify/require" + "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/node" ) type commonPrefixTest struct { From 9df5d2c598b1898dd76a5d0799258776a4e1baa1 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 09:41:32 +0000 Subject: [PATCH 21/50] Simplify and fix mock generations --- internal/trie/node/branch_encode_test.go | 4 ---- internal/trie/node/buffer.go | 8 ++++---- internal/trie/node/buffer_mock_test.go | 2 +- internal/trie/node/leaf_encode_test.go | 4 ---- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/internal/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go index 2589e37e0f..60be03aa18 100644 --- a/internal/trie/node/branch_encode_test.go +++ b/internal/trie/node/branch_encode_test.go @@ -11,8 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -//go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/encode Buffer - func Test_Branch_Encode(t *testing.T) { t.Parallel() @@ -420,8 +418,6 @@ func Test_encodeChildrenSequentially(t *testing.T) { } } -//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer - func Test_encodeChild(t *testing.T) { t.Parallel() diff --git a/internal/trie/node/buffer.go b/internal/trie/node/buffer.go index c1515f94c6..c4a2e74cf1 100644 --- a/internal/trie/node/buffer.go +++ b/internal/trie/node/buffer.go @@ -5,12 +5,12 @@ package node import "io" +//go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE . Buffer +//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer + // Buffer is an interface with some methods of *bytes.Buffer. type Buffer interface { - Writer + io.Writer Len() int Bytes() []byte } - -// Writer is the io.Writer interface -type Writer io.Writer diff --git a/internal/trie/node/buffer_mock_test.go b/internal/trie/node/buffer_mock_test.go index 1357336f23..8977a1ed52 100644 --- a/internal/trie/node/buffer_mock_test.go +++ b/internal/trie/node/buffer_mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/ChainSafe/gossamer/lib/trie/encode (interfaces: Buffer) +// Source: github.com/ChainSafe/gossamer/internal/trie/node (interfaces: Buffer) // Package node is a generated GoMock package. 
package node diff --git a/internal/trie/node/leaf_encode_test.go b/internal/trie/node/leaf_encode_test.go index 2edcb0e289..fdac0713c8 100644 --- a/internal/trie/node/leaf_encode_test.go +++ b/internal/trie/node/leaf_encode_test.go @@ -11,8 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -//go:generate mockgen -destination=buffer_mock_test.go -package $GOPACKAGE github.com/ChainSafe/gossamer/lib/trie/encode Buffer - func Test_Leaf_Encode(t *testing.T) { t.Parallel() @@ -194,8 +192,6 @@ func Test_Leaf_ScaleEncodeHash(t *testing.T) { } } -//go:generate mockgen -destination=writer_mock_test.go -package $GOPACKAGE io Writer - func Test_Leaf_hash(t *testing.T) { t.Parallel() From d65d600694673b4bfd91385f641536f51f13330c Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 09:53:47 +0000 Subject: [PATCH 22/50] Add licenses --- internal/trie/node/encode_doc.go | 3 +++ internal/trie/node/encode_test.go | 3 +++ internal/trie/node/key_test.go | 3 +++ lib/trie/node.go | 3 +++ lib/trie/proof.go | 3 +++ lib/trie/record/node.go | 3 +++ lib/trie/record/recorder_test.go | 3 +++ 7 files changed, 21 insertions(+) diff --git a/internal/trie/node/encode_doc.go b/internal/trie/node/encode_doc.go index 6cb2699d59..1a8b6a1c0a 100644 --- a/internal/trie/node/encode_doc.go +++ b/internal/trie/node/encode_doc.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package node //nolint:lll diff --git a/internal/trie/node/encode_test.go b/internal/trie/node/encode_test.go index ae8593ba23..49a41ad0e0 100644 --- a/internal/trie/node/encode_test.go +++ b/internal/trie/node/encode_test.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package node import "errors" diff --git a/internal/trie/node/key_test.go b/internal/trie/node/key_test.go index 088ab38eb4..4ec4985527 100644 --- a/internal/trie/node/key_test.go +++ b/internal/trie/node/key_test.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package node import ( diff --git a/lib/trie/node.go b/lib/trie/node.go index 758b4dd7ef..4a2cec4d8f 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package trie import "github.com/ChainSafe/gossamer/internal/trie/node" diff --git a/lib/trie/proof.go b/lib/trie/proof.go index e79c47e28e..c46acad3df 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package trie import ( diff --git a/lib/trie/record/node.go b/lib/trie/record/node.go index eb3299e9bc..19a745c82c 100644 --- a/lib/trie/record/node.go +++ b/lib/trie/record/node.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package record // Node represents a record of a visited node diff --git a/lib/trie/record/recorder_test.go b/lib/trie/record/recorder_test.go index 638661b97a..943f82859d 100644 --- a/lib/trie/record/recorder_test.go +++ b/lib/trie/record/recorder_test.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package record import ( From a443ac50b771336a5435c03be364eff8c6e76a92 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 09:56:49 +0000 Subject: [PATCH 23/50] Use `Node` instead of `node.Node` in `lib/trie` --- internal/trie/node/node.go | 2 +- 
lib/trie/database.go | 16 ++++++++-------- lib/trie/decode.go | 2 +- lib/trie/decode_test.go | 2 +- lib/trie/lookup.go | 2 +- lib/trie/node.go | 2 +- lib/trie/print.go | 2 +- lib/trie/trie.go | 34 +++++++++++++++++----------------- 8 files changed, 31 insertions(+), 31 deletions(-) diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index 63fe57443f..b2c6b7d935 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -3,7 +3,7 @@ package node -// Node is node in the trie and can be a leaf or a branch. +// Node is a node in the trie and can be a leaf or a branch. type Node interface { Encode(buffer Buffer) (err error) // TODO change to io.Writer EncodeAndHash() ([]byte, []byte, error) diff --git a/lib/trie/database.go b/lib/trie/database.go index 534114b0a8..b3bb9e635f 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -33,7 +33,7 @@ func (t *Trie) Store(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) store(db chaindb.Batch, curr node.Node) error { +func (t *Trie) store(db chaindb.Batch, curr Node) error { if curr == nil { return nil } @@ -74,7 +74,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return ErrEmptyProof } - mappedNodes := make(map[string]node.Node, len(proof)) + mappedNodes := make(map[string]Node, len(proof)) // map all the proofs hash -> decoded node // and takes the loop to indentify the root node @@ -105,7 +105,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // loadProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root -func (t *Trie) loadProof(proof map[string]node.Node, curr node.Node) { +func (t *Trie) loadProof(proof map[string]Node, curr Node) { c, ok := curr.(*node.Branch) if !ok { return @@ -150,7 +150,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return t.load(db, t.root) } -func (t *Trie) load(db chaindb.Database, curr node.Node) error { +func (t *Trie) load(db chaindb.Database, curr Node) error { if c, ok := curr.(*node.Branch); ok { for i, child := range c.Children { if child == nil { @@ -183,7 +183,7 @@ func (t *Trie) load(db chaindb.Database, curr node.Node) error { } // GetNodeHashes return hash of each key of the trie. 
-func (t *Trie) GetNodeHashes(curr node.Node, keys map[common.Hash]struct{}) error { +func (t *Trie) GetNodeHashes(curr Node, keys map[common.Hash]struct{}) error { if c, ok := curr.(*node.Branch); ok { for _, child := range c.Children { if child == nil { @@ -251,7 +251,7 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return getFromDB(db, rootNode, k) } -func getFromDB(db chaindb.Database, parent node.Node, key []byte) ([]byte, error) { +func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { var value []byte switch p := parent.(type) { @@ -310,7 +310,7 @@ func (t *Trie) WriteDirty(db chaindb.Database) error { return batch.Flush() } -func (t *Trie) writeDirty(db chaindb.Batch, curr node.Node) error { +func (t *Trie) writeDirty(db chaindb.Batch, curr Node) error { if curr == nil || !curr.IsDirty() { return nil } @@ -358,7 +358,7 @@ func (t *Trie) GetInsertedNodeHashes() ([]common.Hash, error) { return t.getInsertedNodeHashes(t.root) } -func (t *Trie) getInsertedNodeHashes(curr node.Node) ([]common.Hash, error) { +func (t *Trie) getInsertedNodeHashes(curr Node) ([]common.Hash, error) { var nodeHashes []common.Hash if curr == nil || !curr.IsDirty() { return nil, nil diff --git a/lib/trie/decode.go b/lib/trie/decode.go index 75f76f9d6c..5f8de9a0e4 100644 --- a/lib/trie/decode.go +++ b/lib/trie/decode.go @@ -18,7 +18,7 @@ var ( ErrUnknownNodeType = errors.New("unknown node type") ) -func decodeNode(reader io.Reader) (n node.Node, err error) { +func decodeNode(reader io.Reader) (n Node, err error) { buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) defer pools.SingleByteBuffers.Put(buffer) oneByteBuf := buffer.Bytes() diff --git a/lib/trie/decode_test.go b/lib/trie/decode_test.go index 869fee3d41..6d0d39d9e5 100644 --- a/lib/trie/decode_test.go +++ b/lib/trie/decode_test.go @@ -25,7 +25,7 @@ func Test_decodeNode(t *testing.T) { testCases := map[string]struct { reader io.Reader - n node.Node + n Node errWrapped error errMessage string }{ diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index cfe662af66..27a3d4944e 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -21,7 +21,7 @@ func findAndRecord(t *Trie, key []byte, recorder recorder) error { return find(t.root, key, recorder) } -func find(parent node.Node, key []byte, recorder recorder) error { +func find(parent Node, key []byte, recorder recorder) error { enc, hash, err := parent.EncodeAndHash() if err != nil { return err diff --git a/lib/trie/node.go b/lib/trie/node.go index 4a2cec4d8f..8ab60a5455 100644 --- a/lib/trie/node.go +++ b/lib/trie/node.go @@ -5,5 +5,5 @@ package trie import "github.com/ChainSafe/gossamer/internal/trie/node" -// Node is node in the trie and can be a leaf or a branch. +// Node is a node in the trie and can be a leaf or a branch. 
type Node node.Node diff --git a/lib/trie/print.go b/lib/trie/print.go index 8a91160728..954f35c5d9 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -25,7 +25,7 @@ func (t *Trie) String() string { return fmt.Sprintf("\n%s", tree.Print()) } -func (t *Trie) string(tree gotree.Tree, curr node.Node, idx int) { +func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { switch c := curr.(type) { case *node.Branch: buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 97fc3a7c8b..6dfcd03959 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -21,7 +21,7 @@ var EmptyHash, _ = NewEmptyTrie().Hash() // Use NewTrie to create a trie that sits on top of a database. type Trie struct { generation uint64 - root node.Node + root Node childTries map[common.Hash]*Trie // Used to store the child tries. deletedKeys []common.Hash } @@ -32,7 +32,7 @@ func NewEmptyTrie() *Trie { } // NewTrie creates a trie with an existing root node -func NewTrie(root node.Node) *Trie { +func NewTrie(root Node) *Trie { return &Trie{ root: root, childTries: make(map[common.Hash]*Trie), @@ -62,7 +62,7 @@ func (t *Trie) Snapshot() *Trie { return newTrie } -func (t *Trie) maybeUpdateGeneration(n node.Node) node.Node { +func (t *Trie) maybeUpdateGeneration(n Node) Node { if n == nil { return nil } @@ -101,7 +101,7 @@ func (t *Trie) DeepCopy() (*Trie, error) { } // RootNode returns the root of the trie -func (t *Trie) RootNode() node.Node { +func (t *Trie) RootNode() Node { return t.root } @@ -146,7 +146,7 @@ func (t *Trie) Entries() map[string][]byte { return t.entries(t.root, nil, make(map[string][]byte)) } -func (t *Trie) entries(current node.Node, prefix []byte, kv map[string][]byte) map[string][]byte { +func (t *Trie) entries(current Node, prefix []byte, kv map[string][]byte) map[string][]byte { switch c := current.(type) { case *node.Branch: if c.Value != nil { @@ -175,7 +175,7 @@ func (t *Trie) NextKey(key []byte) []byte { return codec.NibblesToKeyLE(next) } -func (t *Trie) nextKey(curr node.Node, prefix, key []byte) []byte { +func (t *Trie) nextKey(curr Node, prefix, key []byte) []byte { switch c := curr.(type) { case *node.Branch: fullKey := append(prefix, c.Key...) 
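The wholesale substitution in this patch works because the `type Node node.Node` declaration added earlier in this series defines a new interface type whose method set is identical to `node.Node`, so values move between the two types without any conversion. A minimal sketch of that property, assuming it compiles from within the gossamer module (the `example` package name and `convert` function are illustrative only, not part of the patch):

    package example

    import "github.com/ChainSafe/gossamer/internal/trie/node"

    // Node mirrors the lib/trie declaration: a distinct interface type
    // with the same method set as node.Node.
    type Node node.Node

    func convert() {
        // A concrete leaf satisfies both interfaces.
        var inner node.Node = &node.Leaf{Key: []byte{1}, Value: []byte{2}}

        // Assignable in both directions because the method sets match.
        var outer Node = inner
        inner = outer
        _ = inner
    }
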
@@ -262,7 +262,7 @@ func (t *Trie) tryPut(key, value []byte) { } // insert attempts to insert a key with value into the trie -func (t *Trie) insert(parent node.Node, key []byte, value node.Node) node.Node { +func (t *Trie) insert(parent Node, key []byte, value Node) Node { switch p := t.maybeUpdateGeneration(parent).(type) { case *node.Branch: n := t.updateBranch(p, key, value) @@ -330,7 +330,7 @@ func (t *Trie) insert(parent node.Node, key []byte, value node.Node) node.Node { // updateBranch attempts to add the value node to a branch // inserts the value node as the branch's child at the index that's // the first nibble of the key -func (t *Trie) updateBranch(p *node.Branch, key []byte, value node.Node) (n node.Node) { +func (t *Trie) updateBranch(p *node.Branch, key []byte, value Node) (n Node) { length := lenCommonPrefix(key, p.Key) // whole parent key matches @@ -412,7 +412,7 @@ func (t *Trie) GetKeysWithPrefix(prefix []byte) [][]byte { return t.getKeysWithPrefix(t.root, []byte{}, p, [][]byte{}) } -func (t *Trie) getKeysWithPrefix(parent node.Node, prefix, key []byte, keys [][]byte) [][]byte { +func (t *Trie) getKeysWithPrefix(parent Node, prefix, key []byte, keys [][]byte) [][]byte { switch p := parent.(type) { case *node.Branch: length := lenCommonPrefix(p.Key, key) @@ -443,7 +443,7 @@ func (t *Trie) getKeysWithPrefix(parent node.Node, prefix, key []byte, keys [][] // addAllKeys appends all keys that are descendants of the parent node to a slice of keys // it uses the prefix to determine the entire key -func (t *Trie) addAllKeys(parent node.Node, prefix []byte, keys [][]byte) [][]byte { +func (t *Trie) addAllKeys(parent Node, prefix []byte, keys [][]byte) [][]byte { switch p := parent.(type) { case *node.Branch: if p.Value != nil { @@ -477,7 +477,7 @@ func (t *Trie) tryGet(key []byte) *node.Leaf { return t.retrieve(t.root, k) } -func (t *Trie) retrieve(parent node.Node, key []byte) *node.Leaf { +func (t *Trie) retrieve(parent Node, key []byte) *node.Leaf { var ( value *node.Leaf ) @@ -526,7 +526,7 @@ func (t *Trie) ClearPrefixLimit(prefix []byte, limit uint32) (uint32, bool) { // clearPrefixLimit deletes the keys having the prefix till limit reached and returns updated trie root node, // true if any node in the trie got updated, and next bool returns true if there is no keys left with prefix. 
-func (t *Trie) clearPrefixLimit(cn node.Node, prefix []byte, limit *uint32) (node.Node, bool, bool) { +func (t *Trie) clearPrefixLimit(cn Node, prefix []byte, limit *uint32) (Node, bool, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { @@ -584,7 +584,7 @@ func (t *Trie) clearPrefixLimit(cn node.Node, prefix []byte, limit *uint32) (nod return nil, false, true } -func (t *Trie) deleteNodes(cn node.Node, prefix []byte, limit *uint32) (node.Node, bool) { +func (t *Trie) deleteNodes(cn Node, prefix []byte, limit *uint32) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { @@ -650,7 +650,7 @@ func (t *Trie) ClearPrefix(prefix []byte) { t.root, _ = t.clearPrefix(t.root, p) } -func (t *Trie) clearPrefix(cn node.Node, prefix []byte) (node.Node, bool) { +func (t *Trie) clearPrefix(cn Node, prefix []byte) (Node, bool) { curr := t.maybeUpdateGeneration(cn) switch c := curr.(type) { case *node.Branch: @@ -706,7 +706,7 @@ func (t *Trie) Delete(key []byte) { t.root, _ = t.delete(t.root, k) } -func (t *Trie) delete(parent node.Node, key []byte) (node.Node, bool) { +func (t *Trie) delete(parent Node, key []byte) (Node, bool) { // Store the current node and return it, if the trie is not updated. switch p := t.maybeUpdateGeneration(parent).(type) { case *node.Branch: @@ -746,8 +746,8 @@ func (t *Trie) delete(parent node.Node, key []byte) (node.Node, bool) { // handleDeletion is called when a value is deleted from a branch // if the updated branch only has 1 child, it should be combined with that child // if the updated branch only has a value, it should be turned into a leaf -func handleDeletion(p *node.Branch, key []byte) node.Node { - var n node.Node = p +func handleDeletion(p *node.Branch, key []byte) Node { + var n Node = p length := lenCommonPrefix(p.Key, key) bitmap := p.ChildrenBitmap() From d1e21ccd03e8929f7b15784de63491746ccdd43c Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 10:04:33 +0000 Subject: [PATCH 24/50] `lib/trie/recorder` -> `internal/trie/recorder` --- {lib => internal}/trie/record/node.go | 0 {lib => internal}/trie/record/recorder.go | 0 {lib => internal}/trie/record/recorder_test.go | 0 lib/trie/lookup.go | 2 +- lib/trie/proof.go | 2 +- 5 files changed, 2 insertions(+), 2 deletions(-) rename {lib => internal}/trie/record/node.go (100%) rename {lib => internal}/trie/record/recorder.go (100%) rename {lib => internal}/trie/record/recorder_test.go (100%) diff --git a/lib/trie/record/node.go b/internal/trie/record/node.go similarity index 100% rename from lib/trie/record/node.go rename to internal/trie/record/node.go diff --git a/lib/trie/record/recorder.go b/internal/trie/record/recorder.go similarity index 100% rename from lib/trie/record/recorder.go rename to internal/trie/record/recorder.go diff --git a/lib/trie/record/recorder_test.go b/internal/trie/record/recorder_test.go similarity index 100% rename from lib/trie/record/recorder_test.go rename to internal/trie/record/recorder_test.go diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 27a3d4944e..abf9ee9192 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -7,7 +7,7 @@ import ( "bytes" "github.com/ChainSafe/gossamer/internal/trie/node" - "github.com/ChainSafe/gossamer/lib/trie/record" + "github.com/ChainSafe/gossamer/internal/trie/record" ) var _ recorder = (*record.Recorder)(nil) diff --git a/lib/trie/proof.go b/lib/trie/proof.go index c46acad3df..b6a5f7e04e 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -10,9 +10,9 @@ import ( "fmt" 
"github.com/ChainSafe/chaindb" + "github.com/ChainSafe/gossamer/internal/trie/record" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/record" ) var ( From 4721f25677acfe3c9d70fe15269152c097beaefc Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 10:06:20 +0000 Subject: [PATCH 25/50] `lib/trie/pools` -> `internal/trie/pools` --- internal/trie/node/branch_encode.go | 2 +- internal/trie/node/hash.go | 2 +- internal/trie/node/key.go | 2 +- internal/trie/node/leaf_encode.go | 2 +- {lib => internal}/trie/pools/pools.go | 0 lib/trie/decode.go | 2 +- lib/trie/print.go | 2 +- lib/trie/trie.go | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename {lib => internal}/trie/pools/pools.go (100%) diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index e24fdd7261..7efd6f3ea1 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -9,9 +9,9 @@ import ( "hash" "io" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/pools" "github.com/ChainSafe/gossamer/pkg/scale" ) diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index 14086a732c..97b74d18b0 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -6,8 +6,8 @@ package node import ( "bytes" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/pools" ) // SetEncodingAndHash sets the encoding and hash slices diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go index 2248efd963..eab6478c6f 100644 --- a/internal/trie/node/key.go +++ b/internal/trie/node/key.go @@ -9,8 +9,8 @@ import ( "fmt" "io" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/pools" ) // SetKey sets the key to the branch. 
diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go index 8caafb10e5..107fce0a93 100644 --- a/internal/trie/node/leaf_encode.go +++ b/internal/trie/node/leaf_encode.go @@ -9,9 +9,9 @@ import ( "hash" "io" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/pools" "github.com/ChainSafe/gossamer/pkg/scale" ) diff --git a/lib/trie/pools/pools.go b/internal/trie/pools/pools.go similarity index 100% rename from lib/trie/pools/pools.go rename to internal/trie/pools/pools.go diff --git a/lib/trie/decode.go b/lib/trie/decode.go index 5f8de9a0e4..730e251b61 100644 --- a/lib/trie/decode.go +++ b/lib/trie/decode.go @@ -10,7 +10,7 @@ import ( "io" "github.com/ChainSafe/gossamer/internal/trie/node" - "github.com/ChainSafe/gossamer/lib/trie/pools" + "github.com/ChainSafe/gossamer/internal/trie/pools" ) var ( diff --git a/lib/trie/print.go b/lib/trie/print.go index 954f35c5d9..340bd2291d 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -8,8 +8,8 @@ import ( "fmt" "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/pools" "github.com/disiqueira/gotree" ) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 6dfcd03959..d5959c52b6 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -8,9 +8,9 @@ import ( "fmt" "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie/codec" - "github.com/ChainSafe/gossamer/lib/trie/pools" ) // EmptyHash is the empty trie hash. 
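A note on why this series keeps moving packages from lib/trie to internal/trie: Go only allows a package under an internal/ directory to be imported by code rooted at the directory containing internal/, so the relocated helpers stay available everywhere inside gossamer while becoming unimportable from other modules. A short sketch of the rule (the compiler error text is approximate):

    // Inside the gossamer module, e.g. from lib/trie: compiles as before.
    package trie

    import "github.com/ChainSafe/gossamer/internal/trie/node"

    var _ = node.LeafType

    // From a different module, the same import is rejected at build time
    // with an error along the lines of:
    //   use of internal package github.com/ChainSafe/gossamer/internal/trie/node not allowed
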
From 8cd2ebb15aca21149adb2585a9c2a839c4df77d9 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 10:07:14 +0000 Subject: [PATCH 26/50] `lib/trie/codec` -> `internal/trie/codec` --- {lib => internal}/trie/codec/nibbles.go | 0 {lib => internal}/trie/codec/nibbles_test.go | 0 internal/trie/node/branch_encode.go | 2 +- internal/trie/node/key.go | 2 +- internal/trie/node/leaf_encode.go | 2 +- lib/trie/database.go | 2 +- lib/trie/proof.go | 2 +- lib/trie/trie.go | 2 +- lib/trie/trie_test.go | 2 +- 9 files changed, 7 insertions(+), 7 deletions(-) rename {lib => internal}/trie/codec/nibbles.go (100%) rename {lib => internal}/trie/codec/nibbles_test.go (100%) diff --git a/lib/trie/codec/nibbles.go b/internal/trie/codec/nibbles.go similarity index 100% rename from lib/trie/codec/nibbles.go rename to internal/trie/codec/nibbles.go diff --git a/lib/trie/codec/nibbles_test.go b/internal/trie/codec/nibbles_test.go similarity index 100% rename from lib/trie/codec/nibbles_test.go rename to internal/trie/codec/nibbles_test.go diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 7efd6f3ea1..5c3ee95fbc 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -9,9 +9,9 @@ import ( "hash" "io" + "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/pkg/scale" ) diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go index eab6478c6f..eddfa1e099 100644 --- a/internal/trie/node/key.go +++ b/internal/trie/node/key.go @@ -9,8 +9,8 @@ import ( "fmt" "io" + "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/pools" - "github.com/ChainSafe/gossamer/lib/trie/codec" ) // SetKey sets the key to the branch. 
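For reference on the codec package moved by this patch: it carries the nibble helpers used throughout lib/trie, including NibblesToKeyLE whose implementation appears later in this series. A small usage sketch, assuming it is placed inside the gossamer module since the package is internal (the byte values are arbitrary):

    package main

    import (
        "fmt"

        "github.com/ChainSafe/gossamer/internal/trie/codec"
    )

    func main() {
        // Even number of nibbles: packed two per byte, first nibble in the high half.
        fmt.Printf("%x\n", codec.NibblesToKeyLE([]byte{0xf, 0xa, 0x0, 0x5})) // fa05

        // Odd number of nibbles: the first nibble gets its own byte.
        fmt.Printf("%x\n", codec.NibblesToKeyLE([]byte{0x1, 0xf, 0xa})) // 01fa
    }
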
diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go index 107fce0a93..2246b66453 100644 --- a/internal/trie/node/leaf_encode.go +++ b/internal/trie/node/leaf_encode.go @@ -9,9 +9,9 @@ import ( "hash" "io" + "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/gossamer/pkg/scale" ) diff --git a/lib/trie/database.go b/lib/trie/database.go index b3bb9e635f..237bf6db9e 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -8,9 +8,9 @@ import ( "errors" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/codec" "github.com/ChainSafe/chaindb" ) diff --git a/lib/trie/proof.go b/lib/trie/proof.go index b6a5f7e04e..2d8444d2db 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -10,9 +10,9 @@ import ( "fmt" "github.com/ChainSafe/chaindb" + "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/record" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/codec" ) var ( diff --git a/lib/trie/trie.go b/lib/trie/trie.go index d5959c52b6..73f3f3ec4d 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -7,10 +7,10 @@ import ( "bytes" "fmt" + "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/codec" ) // EmptyHash is the empty trie hash. diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index d5260a7d20..e564c52df2 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -20,9 +20,9 @@ import ( "github.com/ChainSafe/chaindb" "github.com/stretchr/testify/require" + "github.com/ChainSafe/gossamer/internal/trie/codec" "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/trie/codec" ) type commonPrefixTest struct { From 1c53578a1502d2e5dc41aa8a47711195d507cda3 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 11:02:37 +0000 Subject: [PATCH 27/50] Node interface named returns --- internal/trie/node/generation.go | 4 ++-- internal/trie/node/leaf_encode.go | 2 +- internal/trie/node/node.go | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/trie/node/generation.go b/internal/trie/node/generation.go index ac7adc6aca..cdb8d2f9f3 100644 --- a/internal/trie/node/generation.go +++ b/internal/trie/node/generation.go @@ -9,7 +9,7 @@ func (b *Branch) SetGeneration(generation uint64) { } // GetGeneration returns the generation of the branch. -func (b *Branch) GetGeneration() uint64 { +func (b *Branch) GetGeneration() (generation uint64) { return b.Generation } @@ -19,6 +19,6 @@ func (l *Leaf) SetGeneration(generation uint64) { } // GetGeneration returns the generation of the leaf. 
-func (l *Leaf) GetGeneration() uint64 { +func (l *Leaf) GetGeneration() (generation uint64) { return l.Generation } diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go index 2246b66453..337bf12eed 100644 --- a/internal/trie/node/leaf_encode.go +++ b/internal/trie/node/leaf_encode.go @@ -130,7 +130,7 @@ func (l *Leaf) Encode(buffer Buffer) (err error) { // ScaleEncodeHash hashes the node (blake2b sum on encoded value) // and then SCALE encodes it. This is used to encode children // nodes of branches. -func (l *Leaf) ScaleEncodeHash() (b []byte, err error) { +func (l *Leaf) ScaleEncodeHash() (encoding []byte, err error) { buffer := pools.DigestBuffers.Get().(*bytes.Buffer) buffer.Reset() defer pools.DigestBuffers.Put(buffer) diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index b2c6b7d935..85047bb8aa 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -6,15 +6,15 @@ package node // Node is a node in the trie and can be a leaf or a branch. type Node interface { Encode(buffer Buffer) (err error) // TODO change to io.Writer - EncodeAndHash() ([]byte, []byte, error) - ScaleEncodeHash() (b []byte, err error) + EncodeAndHash() (encoding []byte, hash []byte, err error) + ScaleEncodeHash() (encoding []byte, err error) IsDirty() bool SetDirty(dirty bool) SetKey(key []byte) String() string - SetEncodingAndHash([]byte, []byte) - GetHash() []byte - GetGeneration() uint64 - SetGeneration(uint64) + SetEncodingAndHash(encoding []byte, hash []byte) + GetHash() (hash []byte) + GetGeneration() (generation uint64) + SetGeneration(generation uint64) Copy() Node } From 79c2ecb6edf9aefc4d348eb0611ca4f5c6a41188 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 7 Dec 2021 13:19:13 +0000 Subject: [PATCH 28/50] Use `bytes.NewReader` for readers --- internal/trie/node/decode.go | 14 +++++++------- internal/trie/node/decode_test.go | 8 ++++---- lib/trie/database.go | 10 +++++----- lib/trie/decode_test.go | 12 ++++++------ 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index a640196c4f..733ec85aee 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -79,24 +79,26 @@ func DecodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { } // DecodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. 
-func DecodeLeaf(r io.Reader, header byte) (leaf *Leaf, err error) { +func DecodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { nodeType := header >> 6 if nodeType != 1 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) } - leaf = new(Leaf) + leaf = &Leaf{ + Dirty: true, + } keyLen := header & 0x3f - leaf.Key, err = decodeKey(r, keyLen) + leaf.Key, err = decodeKey(reader, keyLen) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) } - sd := scale.NewDecoder(r) + sd := scale.NewDecoder(reader) var value []byte err = sd.Decode(&value) - if err != nil { + if err != nil && !errors.Is(err, io.EOF) { return nil, fmt.Errorf("%w: %s", ErrDecodeValue, err) } @@ -104,7 +106,5 @@ func DecodeLeaf(r io.Reader, header byte) (leaf *Leaf, err error) { leaf.Value = value } - leaf.Dirty = true - return leaf, nil } diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index daf64b5992..c6840f0a74 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -176,17 +176,17 @@ func Test_DecodeLeaf(t *testing.T) { }, "value decoding error": { reader: bytes.NewBuffer([]byte{ - 9, // key data - // missing value data + 9, // key data + 255, 255, // bad value data }), header: 65, // node type 1 and key length 1 errWrapped: ErrDecodeValue, - errMessage: "cannot decode value: EOF", + errMessage: "cannot decode value: could not decode invalid integer", }, "zero value": { reader: bytes.NewBuffer([]byte{ 9, // key data - 0, // missing value data + // missing value data }), header: 65, // node type 1 and key length 1 leaf: &Leaf{ diff --git a/lib/trie/database.go b/lib/trie/database.go index 237bf6db9e..2f003e701e 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -79,7 +79,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // map all the proofs hash -> decoded node // and takes the loop to indentify the root node for _, rawNode := range proof { - decNode, err := decodeNode(bytes.NewBuffer(rawNode)) + decNode, err := decodeNode(bytes.NewReader(rawNode)) if err != nil { return err } @@ -139,7 +139,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return fmt.Errorf("failed to find root key=%s: %w", root, err) } - t.root, err = decodeNode(bytes.NewBuffer(enc)) + t.root, err = decodeNode(bytes.NewReader(enc)) if err != nil { return err } @@ -163,7 +163,7 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err) } - child, err = decodeNode(bytes.NewBuffer(enc)) + child, err = decodeNode(bytes.NewReader(enc)) if err != nil { return err } @@ -243,7 +243,7 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return nil, fmt.Errorf("failed to find root key=%s: %w", root, err) } - rootNode, err := decodeNode(bytes.NewBuffer(enc)) + rootNode, err := decodeNode(bytes.NewReader(enc)) if err != nil { return nil, err } @@ -278,7 +278,7 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { return nil, fmt.Errorf("failed to find node in database: %w", err) } - child, err := decodeNode(bytes.NewBuffer(enc)) + child, err := decodeNode(bytes.NewReader(enc)) if err != nil { return nil, err } diff --git a/lib/trie/decode_test.go b/lib/trie/decode_test.go index 6d0d39d9e5..34788bbf1a 100644 --- a/lib/trie/decode_test.go +++ b/lib/trie/decode_test.go @@ -30,17 +30,17 @@ func Test_decodeNode(t *testing.T) { errMessage string }{ "no data": { - reader: 
bytes.NewBuffer(nil), + reader: bytes.NewReader(nil), errWrapped: ErrReadHeaderByte, errMessage: "cannot read header byte: EOF", }, "unknown node type": { - reader: bytes.NewBuffer([]byte{0}), + reader: bytes.NewReader([]byte{0}), errWrapped: ErrUnknownNodeType, errMessage: "unknown node type: 0", }, "leaf decoding error": { - reader: bytes.NewBuffer([]byte{ + reader: bytes.NewReader([]byte{ 65, // node type 1 and key length 1 // missing key data byte }), @@ -48,7 +48,7 @@ func Test_decodeNode(t *testing.T) { errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", }, "leaf success": { - reader: bytes.NewBuffer( + reader: bytes.NewReader( append( []byte{ 65, // node type 1 and key length 1 @@ -64,7 +64,7 @@ func Test_decodeNode(t *testing.T) { }, }, "branch decoding error": { - reader: bytes.NewBuffer([]byte{ + reader: bytes.NewReader([]byte{ 129, // node type 2 and key length 1 // missing key data byte }), @@ -72,7 +72,7 @@ func Test_decodeNode(t *testing.T) { errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", }, "branch success": { - reader: bytes.NewBuffer( + reader: bytes.NewReader( []byte{ 129, // node type 2 and key length 1 9, // key data From 5cce780f70f85680a80ca09e6faae2ecbe4bac7a Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 8 Dec 2021 15:07:56 +0000 Subject: [PATCH 29/50] Add `GetValue` method to node interface --- internal/trie/node/node.go | 1 + internal/trie/node/value.go | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 internal/trie/node/value.go diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index 85047bb8aa..178032aa96 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -14,6 +14,7 @@ type Node interface { String() string SetEncodingAndHash(encoding []byte, hash []byte) GetHash() (hash []byte) + GetValue() (value []byte) GetGeneration() (generation uint64) SetGeneration(generation uint64) Copy() Node diff --git a/internal/trie/node/value.go b/internal/trie/node/value.go new file mode 100644 index 0000000000..5ab07fb589 --- /dev/null +++ b/internal/trie/node/value.go @@ -0,0 +1,18 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +// GetValue returns the value of the branch. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the branch. +func (b *Branch) GetValue() (value []byte) { + return b.Value +} + +// GetValue returns the value of the leaf. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the leaf. +func (l *Leaf) GetValue() (value []byte) { + return l.Value +} From 5b800f17ed2339515019f4c10365e6714d0cba9b Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 8 Dec 2021 15:17:43 +0000 Subject: [PATCH 30/50] Add `GetKey` method to node interface --- internal/trie/node/key.go | 14 ++++++++++++++ internal/trie/node/node.go | 1 + 2 files changed, 15 insertions(+) diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go index eddfa1e099..c1dfd56865 100644 --- a/internal/trie/node/key.go +++ b/internal/trie/node/key.go @@ -13,6 +13,20 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/pools" ) +// GetKey returns the key of the branch. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the branch. 
+func (b *Branch) GetKey() (value []byte) { + return b.Key +} + +// GetKey returns the key of the leaf. +// Note it does not copy the byte slice so modifying the returned +// byte slice will modify the byte slice of the leaf. +func (l *Leaf) GetKey() (value []byte) { + return l.Key +} + // SetKey sets the key to the branch. // Note it does not copy it so modifying the passed key // will modify the key stored in the branch. diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index 178032aa96..6c306979bb 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -14,6 +14,7 @@ type Node interface { String() string SetEncodingAndHash(encoding []byte, hash []byte) GetHash() (hash []byte) + GetKey() (key []byte) GetValue() (value []byte) GetGeneration() (generation uint64) SetGeneration(generation uint64) From 93ee8b647f01dc915e0054494b7b4fc3af6e4d98 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 9 Dec 2021 22:55:23 +0000 Subject: [PATCH 31/50] `node.Decode` function --- internal/trie/node/decode.go | 43 +++++++++- internal/trie/node/decode_test.go | 90 +++++++++++++++++++- internal/trie/node/encode_decode_test.go | 2 +- lib/trie/database.go | 10 +-- lib/trie/decode.go | 48 ----------- lib/trie/decode_test.go | 103 ----------------------- 6 files changed, 131 insertions(+), 165 deletions(-) delete mode 100644 lib/trie/decode.go delete mode 100644 lib/trie/decode_test.go diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 733ec85aee..05be18c2b4 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -4,15 +4,18 @@ package node import ( + "bytes" "errors" "fmt" "io" + "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/pkg/scale" ) var ( ErrReadHeaderByte = errors.New("cannot read header byte") + ErrUnknownNodeType = errors.New("unknown node type") ErrNodeTypeIsNotABranch = errors.New("node type is not a branch") ErrNodeTypeIsNotALeaf = errors.New("node type is not a leaf") ErrDecodeValue = errors.New("cannot decode value") @@ -20,12 +23,44 @@ var ( ErrDecodeChildHash = errors.New("cannot decode child hash") ) -// DecodeBranch reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. +// Decode decodes a node from a reader. +// For branch decoding, see the comments on decodeBranch. +// For leaf decoding, see the comments on decodeLeaf. +func Decode(reader io.Reader) (n Node, err error) { + buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) + defer pools.SingleByteBuffers.Put(buffer) + oneByteBuf := buffer.Bytes() + _, err = reader.Read(oneByteBuf) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) + } + header := oneByteBuf[0] + + nodeType := header >> 6 + switch nodeType { + case LeafType: + n, err = decodeLeaf(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode leaf: %w", err) + } + return n, nil + case BranchType, BranchWithValueType: + n, err = decodeBranch(reader, header) + if err != nil { + return nil, fmt.Errorf("cannot decode branch: %w", err) + } + return n, nil + default: + return nil, fmt.Errorf("%w: %d", ErrUnknownNodeType, nodeType) + } +} + +// decodeBranch reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. // Note that since the encoded branch stores the hash of the children nodes, we are not // reconstructing the child nodes from the encoding. This function instead stubs where the // children are known to be with an empty leaf. 
The children nodes hashes are then used to // find other values using the persistent database. -func DecodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { +func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { nodeType := header >> 6 if nodeType != 2 && nodeType != 3 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) @@ -78,8 +113,8 @@ func DecodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { return branch, nil } -// DecodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. -func DecodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { +// decodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. +func decodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { nodeType := header >> 6 if nodeType != 1 { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index c6840f0a74..b3b2d91ef8 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -31,7 +31,89 @@ func concatByteSlices(slices [][]byte) (concatenated []byte) { return concatenated } -func Test_DecodeBranch(t *testing.T) { +func Test_Decode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + n Node + errWrapped error + errMessage string + }{ + "no data": { + reader: bytes.NewReader(nil), + errWrapped: ErrReadHeaderByte, + errMessage: "cannot read header byte: EOF", + }, + "unknown node type": { + reader: bytes.NewReader([]byte{0}), + errWrapped: ErrUnknownNodeType, + errMessage: "unknown node type: 0", + }, + "leaf decoding error": { + reader: bytes.NewReader([]byte{ + 65, // node type 1 and key length 1 + // missing key data byte + }), + errWrapped: ErrReadKeyData, + errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", + }, + "leaf success": { + reader: bytes.NewReader( + append( + []byte{ + 65, // node type 1 and key length 1 + 9, // key data + }, + scaleEncodeBytes(t, 1, 2, 3)..., + ), + ), + n: &Leaf{ + Key: []byte{9}, + Value: []byte{1, 2, 3}, + Dirty: true, + }, + }, + "branch decoding error": { + reader: bytes.NewReader([]byte{ + 129, // node type 2 and key length 1 + // missing key data byte + }), + errWrapped: ErrReadKeyData, + errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", + }, + "branch success": { + reader: bytes.NewReader( + []byte{ + 129, // node type 2 and key length 1 + 9, // key data + 0, 0, // no children bitmap + }, + ), + n: &Branch{ + Key: []byte{9}, + Dirty: true, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + n, err := Decode(testCase.reader) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.n, n) + }) + } +} + +func Test_decodeBranch(t *testing.T) { t.Parallel() testCases := map[string]struct { @@ -139,7 +221,7 @@ func Test_DecodeBranch(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - branch, err := DecodeBranch(testCase.reader, testCase.header) + branch, err := decodeBranch(testCase.reader, testCase.header) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { @@ -150,7 +232,7 @@ func Test_DecodeBranch(t *testing.T) { } } -func Test_DecodeLeaf(t *testing.T) { +func Test_decodeLeaf(t *testing.T) 
{ t.Parallel() testCases := map[string]struct { @@ -215,7 +297,7 @@ func Test_DecodeLeaf(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - leaf, err := DecodeLeaf(testCase.reader, testCase.header) + leaf, err := decodeLeaf(testCase.reader, testCase.header) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go index cc380060d6..f8ba60df3f 100644 --- a/internal/trie/node/encode_decode_test.go +++ b/internal/trie/node/encode_decode_test.go @@ -80,7 +80,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { require.NoError(t, err) header := oneBuffer[0] - resultBranch, err := DecodeBranch(buffer, header) + resultBranch, err := decodeBranch(buffer, header) require.NoError(t, err) assert.Equal(t, testCase.branchDecoded, resultBranch) diff --git a/lib/trie/database.go b/lib/trie/database.go index 2f003e701e..362720c5ce 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -79,7 +79,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // map all the proofs hash -> decoded node // and takes the loop to indentify the root node for _, rawNode := range proof { - decNode, err := decodeNode(bytes.NewReader(rawNode)) + decNode, err := node.Decode(bytes.NewReader(rawNode)) if err != nil { return err } @@ -139,7 +139,7 @@ func (t *Trie) Load(db chaindb.Database, root common.Hash) error { return fmt.Errorf("failed to find root key=%s: %w", root, err) } - t.root, err = decodeNode(bytes.NewReader(enc)) + t.root, err = node.Decode(bytes.NewReader(enc)) if err != nil { return err } @@ -163,7 +163,7 @@ func (t *Trie) load(db chaindb.Database, curr Node) error { return fmt.Errorf("failed to find node key=%x index=%d: %w", hash, i, err) } - child, err = decodeNode(bytes.NewReader(enc)) + child, err = node.Decode(bytes.NewReader(enc)) if err != nil { return err } @@ -243,7 +243,7 @@ func GetFromDB(db chaindb.Database, root common.Hash, key []byte) ([]byte, error return nil, fmt.Errorf("failed to find root key=%s: %w", root, err) } - rootNode, err := decodeNode(bytes.NewReader(enc)) + rootNode, err := node.Decode(bytes.NewReader(enc)) if err != nil { return nil, err } @@ -278,7 +278,7 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { return nil, fmt.Errorf("failed to find node in database: %w", err) } - child, err := decodeNode(bytes.NewReader(enc)) + child, err := node.Decode(bytes.NewReader(enc)) if err != nil { return nil, err } diff --git a/lib/trie/decode.go b/lib/trie/decode.go deleted file mode 100644 index 730e251b61..0000000000 --- a/lib/trie/decode.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "errors" - "fmt" - "io" - - "github.com/ChainSafe/gossamer/internal/trie/node" - "github.com/ChainSafe/gossamer/internal/trie/pools" -) - -var ( - ErrReadHeaderByte = errors.New("cannot read header byte") - ErrUnknownNodeType = errors.New("unknown node type") -) - -func decodeNode(reader io.Reader) (n Node, err error) { - buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) - defer pools.SingleByteBuffers.Put(buffer) - oneByteBuf := buffer.Bytes() - _, err = reader.Read(oneByteBuf) - if err != nil { - return nil, fmt.Errorf("%w: %s", ErrReadHeaderByte, err) - } - header := oneByteBuf[0] - - nodeType := header >> 6 - switch nodeType { - case node.LeafType: - n, err = node.DecodeLeaf(reader, header) - if err != nil { - return nil, 
fmt.Errorf("cannot decode leaf: %w", err) - } - return n, nil - case node.BranchType, node.BranchWithValueType: - n, err = node.DecodeBranch(reader, header) - if err != nil { - return nil, fmt.Errorf("cannot decode branch: %w", err) - } - return n, nil - default: - return nil, fmt.Errorf("%w: %d", ErrUnknownNodeType, nodeType) - } -} diff --git a/lib/trie/decode_test.go b/lib/trie/decode_test.go deleted file mode 100644 index 34788bbf1a..0000000000 --- a/lib/trie/decode_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package trie - -import ( - "bytes" - "io" - "testing" - - "github.com/ChainSafe/gossamer/internal/trie/node" - "github.com/ChainSafe/gossamer/pkg/scale" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { - encoded, err := scale.Marshal(b) - require.NoError(t, err) - return encoded -} - -func Test_decodeNode(t *testing.T) { - t.Parallel() - - testCases := map[string]struct { - reader io.Reader - n Node - errWrapped error - errMessage string - }{ - "no data": { - reader: bytes.NewReader(nil), - errWrapped: ErrReadHeaderByte, - errMessage: "cannot read header byte: EOF", - }, - "unknown node type": { - reader: bytes.NewReader([]byte{0}), - errWrapped: ErrUnknownNodeType, - errMessage: "unknown node type: 0", - }, - "leaf decoding error": { - reader: bytes.NewReader([]byte{ - 65, // node type 1 and key length 1 - // missing key data byte - }), - errWrapped: node.ErrReadKeyData, - errMessage: "cannot decode leaf: cannot decode key: cannot read key data: EOF", - }, - "leaf success": { - reader: bytes.NewReader( - append( - []byte{ - 65, // node type 1 and key length 1 - 9, // key data - }, - scaleEncodeBytes(t, 1, 2, 3)..., - ), - ), - n: &node.Leaf{ - Key: []byte{9}, - Value: []byte{1, 2, 3}, - Dirty: true, - }, - }, - "branch decoding error": { - reader: bytes.NewReader([]byte{ - 129, // node type 2 and key length 1 - // missing key data byte - }), - errWrapped: node.ErrReadKeyData, - errMessage: "cannot decode branch: cannot decode key: cannot read key data: EOF", - }, - "branch success": { - reader: bytes.NewReader( - []byte{ - 129, // node type 2 and key length 1 - 9, // key data - 0, 0, // no children bitmap - }, - ), - n: &node.Branch{ - Key: []byte{9}, - Dirty: true, - }, - }, - } - - for name, testCase := range testCases { - testCase := testCase - t.Run(name, func(t *testing.T) { - t.Parallel() - - n, err := decodeNode(testCase.reader) - - assert.ErrorIs(t, err, testCase.errWrapped) - if err != nil { - assert.EqualError(t, err, testCase.errMessage) - } - assert.Equal(t, testCase.n, n) - }) - } -} From c054516588706bdcdd050d68ef8d9f9e304e281e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 17:08:50 +0200 Subject: [PATCH 32/50] Apply suggestions from @kishansagathiya's code review Co-authored-by: Kishan Sagathiya --- internal/trie/codec/nibbles.go | 17 +++++++++-------- internal/trie/node/children.go | 5 +++-- internal/trie/node/decode.go | 6 +++--- internal/trie/node/encode_test.go | 2 +- 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/internal/trie/codec/nibbles.go b/internal/trie/codec/nibbles.go index 11e5a2e818..7b6f9bd4de 100644 --- a/internal/trie/codec/nibbles.go +++ b/internal/trie/codec/nibbles.go @@ -10,18 +10,19 @@ package codec // [ 0000 in[0] | in[1] in[2] | ... | in[k-2] in[k-1] ] // Otherwise, the result is // [ in[0] in[1] | ... 
| in[k-2] in[k-1] ] -func NibblesToKeyLE(nibbles []byte) (keyLE []byte) { +func NibblesToKeyLE(nibbles []byte) []byte { if len(nibbles)%2 == 0 { - keyLE = make([]byte, len(nibbles)/2) + keyLE := make([]byte, len(nibbles)/2) for i := 0; i < len(nibbles); i += 2 { keyLE[i/2] = (nibbles[i] << 4 & 0xf0) | (nibbles[i+1] & 0xf) } - } else { - keyLE = make([]byte, len(nibbles)/2+1) - keyLE[0] = nibbles[0] - for i := 2; i < len(nibbles); i += 2 { - keyLE[i/2] = (nibbles[i-1] << 4 & 0xf0) | (nibbles[i] & 0xf) - } + return keyLE + } + + keyLE := make([]byte, len(nibbles)/2+1) + keyLE[0] = nibbles[0] + for i := 2; i < len(nibbles); i += 2 { + keyLE[i/2] = (nibbles[i-1] << 4 & 0xf0) | (nibbles[i] & 0xf) } return keyLE diff --git a/internal/trie/node/children.go b/internal/trie/node/children.go index bd581cf657..02d23b7cab 100644 --- a/internal/trie/node/children.go +++ b/internal/trie/node/children.go @@ -7,9 +7,10 @@ package node // of the children in the branch. func (b *Branch) ChildrenBitmap() (bitmap uint16) { for i := uint(0); i < 16; i++ { - if b.Children[i] != nil { - bitmap = bitmap | 1<> 6 - if nodeType != 2 && nodeType != 3 { + if nodeType != BranchType && nodeType != BranchWithValueType { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) } @@ -82,7 +82,7 @@ func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { sd := scale.NewDecoder(reader) - if nodeType == 3 { + if nodeType == BranchWithValueType { var value []byte // branch w/ value err := sd.Decode(&value) @@ -116,7 +116,7 @@ func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { // decodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. func decodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { nodeType := header >> 6 - if nodeType != 1 { + if nodeType != LeafType { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) } diff --git a/internal/trie/node/encode_test.go b/internal/trie/node/encode_test.go index 49a41ad0e0..cc72efc06a 100644 --- a/internal/trie/node/encode_test.go +++ b/internal/trie/node/encode_test.go @@ -7,7 +7,7 @@ import "errors" type writeCall struct { written []byte - n int + n int // number of bytes err error } From 87f058de8d9c0d36ff94c4b015248112e0df0bb7 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 15:21:56 +0000 Subject: [PATCH 33/50] Improve ScaleEncodeHash error wrapping --- internal/trie/node/branch_encode.go | 6 +++--- internal/trie/node/leaf_encode.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 5c3ee95fbc..599feeb39e 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -25,12 +25,12 @@ func (b *Branch) ScaleEncodeHash() (encoding []byte, err error) { err = b.hash(buffer) if err != nil { - return nil, fmt.Errorf("cannot hash node: %w", err) + return nil, fmt.Errorf("cannot hash branch: %w", err) } encoding, err = scale.Marshal(buffer.Bytes()) if err != nil { - return nil, fmt.Errorf("cannot scale encode hashed node: %w", err) + return nil, fmt.Errorf("cannot scale encode hashed branch: %w", err) } return encoding, nil @@ -113,7 +113,7 @@ func (b *Branch) Encode(buffer Buffer) (err error) { } } - const parallel = false // TODO + const parallel = false // TODO Done in pull request #2081 if parallel { err = encodeChildrenInParallel(b.Children, buffer) } else { diff --git a/internal/trie/node/leaf_encode.go 
b/internal/trie/node/leaf_encode.go index 337bf12eed..184bfca644 100644 --- a/internal/trie/node/leaf_encode.go +++ b/internal/trie/node/leaf_encode.go @@ -137,12 +137,12 @@ func (l *Leaf) ScaleEncodeHash() (encoding []byte, err error) { err = l.hash(buffer) if err != nil { - return nil, fmt.Errorf("cannot hash node: %w", err) + return nil, fmt.Errorf("cannot hash leaf: %w", err) } scEncChild, err := scale.Marshal(buffer.Bytes()) if err != nil { - return nil, fmt.Errorf("cannot scale encode hashed node: %w", err) + return nil, fmt.Errorf("cannot scale encode hashed leaf: %w", err) } return scEncChild, nil } From 9fd4036355a0130a92c71529fe2a7ec857f5162a Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 15:22:20 +0000 Subject: [PATCH 34/50] Shorten children bitmap bitwise or --- internal/trie/node/children.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/trie/node/children.go b/internal/trie/node/children.go index 02d23b7cab..be4f9e47ea 100644 --- a/internal/trie/node/children.go +++ b/internal/trie/node/children.go @@ -10,7 +10,7 @@ func (b *Branch) ChildrenBitmap() (bitmap uint16) { if b.Children[i] == nil { continue } - bitmap = bitmap | 1< Date: Mon, 13 Dec 2021 18:57:36 +0000 Subject: [PATCH 35/50] Add comments for `Dirty` and `Generation` --- internal/trie/node/branch.go | 18 ++++++++++++------ internal/trie/node/leaf.go | 10 ++++++++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index 45547d1b3b..3903f1dcab 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -14,12 +14,18 @@ var _ Node = (*Branch)(nil) // Branch is a branch in the trie. type Branch struct { - Key []byte // partial key - Children [16]Node - Value []byte - Dirty bool - Hash []byte - Encoding []byte + Key []byte // partial key + Children [16]Node + Value []byte + // Dirty is true when the branch differs + // from the node stored in the database. + Dirty bool + Hash []byte + Encoding []byte + // Generation is incremented on every trie Snapshot() call. + // Nodes that are part of the trie are then gradually updated + // to have a matching generation number as well, if they are + // still relevant. Generation uint64 sync.RWMutex } diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go index d77fe3339f..01984eab4f 100644 --- a/internal/trie/node/leaf.go +++ b/internal/trie/node/leaf.go @@ -14,12 +14,18 @@ var _ Node = (*Leaf)(nil) // Leaf is a leaf in the trie. type Leaf struct { - Key []byte // partial key - Value []byte + Key []byte // partial key + Value []byte + // Dirty is true when the branch differs + // from the node stored in the database. Dirty bool Hash []byte Encoding []byte encodingMu sync.RWMutex + // Generation is incremented on every trie Snapshot() call. + // Nodes that are part of the trie are then gradually updated + // to have a matching generation number as well, if they are + // still relevant. 
Generation uint64 sync.RWMutex } From 8c213df630dcf2d2cd591c9e4bdec26766bee7d5 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 19:06:15 +0000 Subject: [PATCH 36/50] Unexport `Generation` node field --- internal/trie/node/branch.go | 14 ++++++++++++-- internal/trie/node/copy.go | 4 ++-- internal/trie/node/generation.go | 8 ++++---- internal/trie/node/leaf.go | 14 ++++++++++++-- lib/trie/print.go | 10 ++++++---- lib/trie/trie.go | 10 +++++++--- 6 files changed, 43 insertions(+), 17 deletions(-) diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index 3903f1dcab..8669579aac 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -22,14 +22,24 @@ type Branch struct { Dirty bool Hash []byte Encoding []byte - // Generation is incremented on every trie Snapshot() call. + // generation is incremented on every trie Snapshot() call. // Nodes that are part of the trie are then gradually updated // to have a matching generation number as well, if they are // still relevant. - Generation uint64 + generation uint64 sync.RWMutex } +// NewBranch creates a new branch using the arguments given. +func NewBranch(key, value []byte, dirty bool, generation uint64) *Branch { + return &Branch{ + Key: key, + Value: value, + Dirty: dirty, + generation: generation, + } +} + func (b *Branch) String() string { if len(b.Value) > 1024 { return fmt.Sprintf("key=%x childrenBitmap=%16b value (hashed)=%x dirty=%v", diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index 41f1b068d1..45add922a1 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -15,7 +15,7 @@ func (b *Branch) Copy() Node { Dirty: b.Dirty, Hash: make([]byte, len(b.Hash)), Encoding: make([]byte, len(b.Encoding)), - Generation: b.Generation, + generation: b.generation, } copy(cpy.Key, b.Key) @@ -44,7 +44,7 @@ func (l *Leaf) Copy() Node { Dirty: l.Dirty, Hash: make([]byte, len(l.Hash)), Encoding: make([]byte, len(l.Encoding)), - Generation: l.Generation, + generation: l.generation, } copy(cpy.Key, l.Key) copy(cpy.Value, l.Value) diff --git a/internal/trie/node/generation.go b/internal/trie/node/generation.go index cdb8d2f9f3..113c283328 100644 --- a/internal/trie/node/generation.go +++ b/internal/trie/node/generation.go @@ -5,20 +5,20 @@ package node // SetGeneration sets the generation given to the branch. func (b *Branch) SetGeneration(generation uint64) { - b.Generation = generation + b.generation = generation } // GetGeneration returns the generation of the branch. func (b *Branch) GetGeneration() (generation uint64) { - return b.Generation + return b.generation } // SetGeneration sets the generation given to the leaf. func (l *Leaf) SetGeneration(generation uint64) { - l.Generation = generation + l.generation = generation } // GetGeneration returns the generation of the leaf. func (l *Leaf) GetGeneration() (generation uint64) { - return l.Generation + return l.generation } diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go index 01984eab4f..7a42f55ab3 100644 --- a/internal/trie/node/leaf.go +++ b/internal/trie/node/leaf.go @@ -22,14 +22,24 @@ type Leaf struct { Hash []byte Encoding []byte encodingMu sync.RWMutex - // Generation is incremented on every trie Snapshot() call. + // generation is incremented on every trie Snapshot() call. // Nodes that are part of the trie are then gradually updated // to have a matching generation number as well, if they are // still relevant. 
- Generation uint64 + generation uint64 sync.RWMutex } +// NewLeaf creates a new leaf using the arguments given. +func NewLeaf(key, value []byte, dirty bool, generation uint64) *Leaf { + return &Leaf{ + Key: key, + Value: value, + Dirty: dirty, + generation: generation, + } +} + func (l *Leaf) String() string { if len(l.Value) > 1024 { return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.Key, common.MustBlake2bHash(l.Value), l.Dirty) diff --git a/lib/trie/print.go b/lib/trie/print.go index 340bd2291d..4b3ffcc4da 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -36,9 +36,10 @@ func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { var bstr string if len(c.Encoding) > 1024 { - bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.Encoding), c.Generation) + bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", + idx, c, common.MustBlake2bHash(c.Encoding), c.GetGeneration()) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.Generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.GetGeneration()) } pools.EncodingBuffers.Put(buffer) @@ -61,9 +62,10 @@ func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { var bstr string if len(c.Encoding) > 1024 { - bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", idx, c.String(), common.MustBlake2bHash(c.Encoding), c.Generation) + bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", + idx, c.String(), common.MustBlake2bHash(c.Encoding), c.GetGeneration()) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.Generation) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.GetGeneration()) } pools.EncodingBuffers.Put(buffer) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 73f3f3ec4d..854b3e3b35 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -258,7 +258,7 @@ func (t *Trie) Put(key, value []byte) { func (t *Trie) tryPut(key, value []byte) { k := codec.KeyLEToNibbles(key) - t.root = t.insert(t.root, k, &node.Leaf{Key: nil, Value: value, Dirty: true, Generation: t.generation}) + t.root = t.insert(t.root, k, node.NewLeaf(nil, value, true, t.generation)) } // insert attempts to insert a key with value into the trie @@ -288,7 +288,9 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { length := lenCommonPrefix(key, p.Key) // need to convert this leaf into a branch - br := &node.Branch{Key: key[:length], Dirty: true, Generation: t.generation} + var newBranchValue []byte + const newBranchDirty = true + br := node.NewBranch(key[:length], newBranchValue, newBranchDirty, t.generation) parentKey := p.Key // value goes at this branch @@ -367,7 +369,9 @@ func (t *Trie) updateBranch(p *node.Branch, key []byte, value Node) (n Node) { // we need to branch out at the point where the keys diverge // update partial keys, new branch has key up to matching length - br := &node.Branch{Key: key[:length], Dirty: true, Generation: t.generation} + var newBranchValue []byte + const newBranchDirty = true + br := node.NewBranch(key[:length], newBranchValue, newBranchDirty, t.generation) parentIndex := p.Key[length] br.Children[parentIndex] = t.insert(nil, p.Key[length+1:], p) From db578d44baa7ca076a76b33f7105d55952680940 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 19:12:47 +0000 Subject: [PATCH 37/50] Unexport `Dirty` field for node --- internal/trie/node/branch.go | 10 +++++----- internal/trie/node/branch_encode.go | 2 +- 
internal/trie/node/copy.go | 4 ++-- internal/trie/node/decode.go | 4 ++-- internal/trie/node/decode_test.go | 12 ++++++------ internal/trie/node/dirty.go | 8 ++++---- internal/trie/node/encode_decode_test.go | 8 ++++---- internal/trie/node/hash.go | 2 +- internal/trie/node/leaf.go | 8 ++++---- internal/trie/node/leaf_encode.go | 2 +- lib/trie/trie.go | 6 +++--- lib/trie/trie_test.go | 9 ++++++--- 12 files changed, 39 insertions(+), 36 deletions(-) diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index 8669579aac..2f9af9d8d4 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -17,9 +17,9 @@ type Branch struct { Key []byte // partial key Children [16]Node Value []byte - // Dirty is true when the branch differs + // dirty is true when the branch differs // from the node stored in the database. - Dirty bool + dirty bool Hash []byte Encoding []byte // generation is incremented on every trie Snapshot() call. @@ -35,7 +35,7 @@ func NewBranch(key, value []byte, dirty bool, generation uint64) *Branch { return &Branch{ Key: key, Value: value, - Dirty: dirty, + dirty: dirty, generation: generation, } } @@ -43,8 +43,8 @@ func NewBranch(key, value []byte, dirty bool, generation uint64) *Branch { func (b *Branch) String() string { if len(b.Value) > 1024 { return fmt.Sprintf("key=%x childrenBitmap=%16b value (hashed)=%x dirty=%v", - b.Key, b.ChildrenBitmap(), common.MustBlake2bHash(b.Value), b.Dirty) + b.Key, b.ChildrenBitmap(), common.MustBlake2bHash(b.Value), b.dirty) } return fmt.Sprintf("key=%x childrenBitmap=%16b value=%v dirty=%v", - b.Key, b.ChildrenBitmap(), b.Value, b.Dirty) + b.Key, b.ChildrenBitmap(), b.Value, b.dirty) } diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 599feeb39e..6520e234be 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -76,7 +76,7 @@ func (b *Branch) hash(digestBuffer io.Writer) (err error) { // Encode encodes a branch with the encoding specified at the top of this package // to the buffer given. 
func (b *Branch) Encode(buffer Buffer) (err error) { - if !b.Dirty && b.Encoding != nil { + if !b.dirty && b.Encoding != nil { _, err = buffer.Write(b.Encoding) if err != nil { return fmt.Errorf("cannot write stored encoding to buffer: %w", err) diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index 45add922a1..efcb0e8eae 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -12,7 +12,7 @@ func (b *Branch) Copy() Node { Key: make([]byte, len(b.Key)), Children: b.Children, // copy interface pointers Value: nil, - Dirty: b.Dirty, + dirty: b.dirty, Hash: make([]byte, len(b.Hash)), Encoding: make([]byte, len(b.Encoding)), generation: b.generation, @@ -41,7 +41,7 @@ func (l *Leaf) Copy() Node { cpy := &Leaf{ Key: make([]byte, len(l.Key)), Value: make([]byte, len(l.Value)), - Dirty: l.Dirty, + dirty: l.dirty, Hash: make([]byte, len(l.Hash)), Encoding: make([]byte, len(l.Encoding)), generation: l.generation, diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 12dcbdcd3c..ccebf20bb3 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -108,7 +108,7 @@ func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { } } - branch.Dirty = true + branch.dirty = true return branch, nil } @@ -121,7 +121,7 @@ func decodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { } leaf = &Leaf{ - Dirty: true, + dirty: true, } keyLen := header & 0x3f diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index b3b2d91ef8..99fedeca2f 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -71,7 +71,7 @@ func Test_Decode(t *testing.T) { n: &Leaf{ Key: []byte{9}, Value: []byte{1, 2, 3}, - Dirty: true, + dirty: true, }, }, "branch decoding error": { @@ -92,7 +92,7 @@ func Test_Decode(t *testing.T) { ), n: &Branch{ Key: []byte{9}, - Dirty: true, + dirty: true, }, }, } @@ -176,7 +176,7 @@ func Test_decodeBranch(t *testing.T) { Hash: []byte{1, 2, 3, 4, 5}, }, }, - Dirty: true, + dirty: true, }, }, "value decoding error for node type 3": { @@ -211,7 +211,7 @@ func Test_decodeBranch(t *testing.T) { Hash: []byte{1, 2, 3, 4, 5}, }, }, - Dirty: true, + dirty: true, }, }, } @@ -273,7 +273,7 @@ func Test_decodeLeaf(t *testing.T) { header: 65, // node type 1 and key length 1 leaf: &Leaf{ Key: []byte{9}, - Dirty: true, + dirty: true, }, }, "success": { @@ -287,7 +287,7 @@ func Test_decodeLeaf(t *testing.T) { leaf: &Leaf{ Key: []byte{9}, Value: []byte{1, 2, 3, 4, 5}, - Dirty: true, + dirty: true, }, }, } diff --git a/internal/trie/node/dirty.go b/internal/trie/node/dirty.go index 7922139b18..27d0367014 100644 --- a/internal/trie/node/dirty.go +++ b/internal/trie/node/dirty.go @@ -5,20 +5,20 @@ package node // IsDirty returns the dirty status of the branch. func (b *Branch) IsDirty() bool { - return b.Dirty + return b.dirty } // SetDirty sets the dirty status to the branch. func (b *Branch) SetDirty(dirty bool) { - b.Dirty = dirty + b.dirty = dirty } // IsDirty returns the dirty status of the leaf. func (l *Leaf) IsDirty() bool { - return l.Dirty + return l.dirty } // SetDirty sets the dirty status to the leaf. 
func (l *Leaf) SetDirty(dirty bool) { - l.Dirty = dirty + l.dirty = dirty } diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go index f8ba60df3f..bb483e3bf2 100644 --- a/internal/trie/node/encode_decode_test.go +++ b/internal/trie/node/encode_decode_test.go @@ -22,7 +22,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { branchToEncode: new(Branch), branchDecoded: &Branch{ Key: []byte{}, - Dirty: true, + dirty: true, }, }, "branch with key 5": { @@ -31,7 +31,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { }, branchDecoded: &Branch{ Key: []byte{5}, - Dirty: true, + dirty: true, }, }, "branch with two bytes key": { @@ -40,7 +40,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { }, branchDecoded: &Branch{ Key: []byte{0xf, 0xa}, - Dirty: true, + dirty: true, }, }, "branch with child": { @@ -60,7 +60,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { Hash: []byte{0x41, 0x9, 0x4, 0xa}, }, }, - Dirty: true, + dirty: true, }, }, } diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index 97b74d18b0..9047e885aa 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -30,7 +30,7 @@ func (b *Branch) GetHash() []byte { // If the encoding is less than 32 bytes, the hash returned // is the encoding and not the hash of the encoding. func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { - if !b.Dirty && b.Encoding != nil && b.Hash != nil { + if !b.dirty && b.Encoding != nil && b.Hash != nil { return b.Encoding, b.Hash, nil } diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go index 7a42f55ab3..a5c0b88d46 100644 --- a/internal/trie/node/leaf.go +++ b/internal/trie/node/leaf.go @@ -18,7 +18,7 @@ type Leaf struct { Value []byte // Dirty is true when the branch differs // from the node stored in the database. 
- Dirty bool + dirty bool Hash []byte Encoding []byte encodingMu sync.RWMutex @@ -35,14 +35,14 @@ func NewLeaf(key, value []byte, dirty bool, generation uint64) *Leaf { return &Leaf{ Key: key, Value: value, - Dirty: dirty, + dirty: dirty, generation: generation, } } func (l *Leaf) String() string { if len(l.Value) > 1024 { - return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.Key, common.MustBlake2bHash(l.Value), l.Dirty) + return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.Key, common.MustBlake2bHash(l.Value), l.dirty) } - return fmt.Sprintf("leaf key=%x value=%v dirty=%v", l.Key, l.Value, l.Dirty) + return fmt.Sprintf("leaf key=%x value=%v dirty=%v", l.Key, l.Value, l.dirty) } diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go index 184bfca644..526ce8cb32 100644 --- a/internal/trie/node/leaf_encode.go +++ b/internal/trie/node/leaf_encode.go @@ -87,7 +87,7 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { // NodeHeader | Extra partial key length | Partial Key | Value func (l *Leaf) Encode(buffer Buffer) (err error) { l.encodingMu.RLock() - if !l.Dirty && l.Encoding != nil { + if !l.dirty && l.Encoding != nil { _, err = buffer.Write(l.Encoding) l.encodingMu.RUnlock() if err != nil { diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 854b3e3b35..94a56053db 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -280,7 +280,7 @@ func (t *Trie) insert(parent Node, key []byte, value Node) Node { if p.Value != nil && bytes.Equal(p.Key, key) { if !bytes.Equal(value.(*node.Leaf).Value, p.Value) { p.Value = value.(*node.Leaf).Value - p.Dirty = true + p.SetDirty(true) } return p } @@ -492,7 +492,7 @@ func (t *Trie) retrieve(parent Node, key []byte) *node.Leaf { // found the value at this node if bytes.Equal(p.Key, key) || len(key) == 0 { - return &node.Leaf{Key: p.Key, Value: p.Value, Dirty: false} + return node.NewLeaf(p.Key, p.Value, false, 0) } // did not find value @@ -757,7 +757,7 @@ func handleDeletion(p *node.Branch, key []byte) Node { // if branch has no children, just a value, turn it into a leaf if bitmap == 0 && p.Value != nil { - n = &node.Leaf{Key: key[:length], Value: p.Value, Dirty: true} + n = node.NewBranch(key[:length], p.Value, true, 0) } else if p.NumChildren() == 1 && p.Value == nil { // there is only 1 child and no value, combine the child branch with this branch // find index of child diff --git a/lib/trie/trie_test.go b/lib/trie/trie_test.go index e564c52df2..cc19116a50 100644 --- a/lib/trie/trie_test.go +++ b/lib/trie/trie_test.go @@ -944,11 +944,14 @@ func TestClearPrefix_Small(t *testing.T) { } ssTrie.ClearPrefix([]byte("noo")) - require.Equal(t, ssTrie.root, &node.Leaf{ + + expectedRoot := &node.Leaf{ Key: codec.KeyLEToNibbles([]byte("other")), Value: []byte("other"), - Dirty: true, - }) + } + expectedRoot.SetDirty(true) + + require.Equal(t, expectedRoot, ssTrie.root) // Get the updated root hash of all tries. 
tHash, err = trie.Hash() From a4da41a6463da98fd4f3bf3c89766d3ff2f86fbf Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 19:17:09 +0000 Subject: [PATCH 38/50] Unexport node's `Hash` as `hashDigest` --- internal/trie/node/branch.go | 6 +++--- internal/trie/node/copy.go | 8 ++++---- internal/trie/node/decode.go | 2 +- internal/trie/node/decode_test.go | 4 ++-- internal/trie/node/encode_decode_test.go | 2 +- internal/trie/node/hash.go | 18 +++++++++--------- internal/trie/node/leaf.go | 2 +- internal/trie/node/leaf_encode.go | 18 +++++++++--------- lib/trie/database.go | 2 +- 9 files changed, 31 insertions(+), 31 deletions(-) diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index 2f9af9d8d4..09005d2723 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -19,9 +19,9 @@ type Branch struct { Value []byte // dirty is true when the branch differs // from the node stored in the database. - dirty bool - Hash []byte - Encoding []byte + dirty bool + hashDigest []byte + Encoding []byte // generation is incremented on every trie Snapshot() call. // Nodes that are part of the trie are then gradually updated // to have a matching generation number as well, if they are diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index efcb0e8eae..ef5d599eb9 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -13,7 +13,7 @@ func (b *Branch) Copy() Node { Children: b.Children, // copy interface pointers Value: nil, dirty: b.dirty, - Hash: make([]byte, len(b.Hash)), + hashDigest: make([]byte, len(b.hashDigest)), Encoding: make([]byte, len(b.Encoding)), generation: b.generation, } @@ -25,7 +25,7 @@ func (b *Branch) Copy() Node { copy(cpy.Value, b.Value) } - copy(cpy.Hash, b.Hash) + copy(cpy.hashDigest, b.hashDigest) copy(cpy.Encoding, b.Encoding) return cpy } @@ -42,13 +42,13 @@ func (l *Leaf) Copy() Node { Key: make([]byte, len(l.Key)), Value: make([]byte, len(l.Value)), dirty: l.dirty, - Hash: make([]byte, len(l.Hash)), + hashDigest: make([]byte, len(l.hashDigest)), Encoding: make([]byte, len(l.Encoding)), generation: l.generation, } copy(cpy.Key, l.Key) copy(cpy.Value, l.Value) - copy(cpy.Hash, l.Hash) + copy(cpy.hashDigest, l.hashDigest) copy(cpy.Encoding, l.Encoding) return cpy } diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index ccebf20bb3..1a9850e10b 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -104,7 +104,7 @@ func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { } branch.Children[i] = &Leaf{ - Hash: hash, + hashDigest: hash, } } diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 99fedeca2f..b5706d7b76 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -173,7 +173,7 @@ func Test_decodeBranch(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, &Leaf{ - Hash: []byte{1, 2, 3, 4, 5}, + hashDigest: []byte{1, 2, 3, 4, 5}, }, }, dirty: true, @@ -208,7 +208,7 @@ func Test_decodeBranch(t *testing.T) { nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, &Leaf{ - Hash: []byte{1, 2, 3, 4, 5}, + hashDigest: []byte{1, 2, 3, 4, 5}, }, }, dirty: true, diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go index bb483e3bf2..898ed9b7e0 100644 --- a/internal/trie/node/encode_decode_test.go +++ b/internal/trie/node/encode_decode_test.go @@ -57,7 +57,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { Key: 
[]byte{5}, Children: [16]Node{ &Leaf{ - Hash: []byte{0x41, 0x9, 0x4, 0xa}, + hashDigest: []byte{0x41, 0x9, 0x4, 0xa}, }, }, dirty: true, diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index 9047e885aa..5fefe928dd 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -14,7 +14,7 @@ import ( // given to the branch. Note it does not copy them, so beware. func (b *Branch) SetEncodingAndHash(enc, hash []byte) { b.Encoding = enc - b.Hash = hash + b.hashDigest = hash } // GetHash returns the hash of the branch. @@ -22,7 +22,7 @@ func (b *Branch) SetEncodingAndHash(enc, hash []byte) { // the returned hash will modify the hash // of the branch. func (b *Branch) GetHash() []byte { - return b.Hash + return b.hashDigest } // EncodeAndHash returns the encoding of the branch and @@ -30,8 +30,8 @@ func (b *Branch) GetHash() []byte { // If the encoding is less than 32 bytes, the hash returned // is the encoding and not the hash of the encoding. func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { - if !b.dirty && b.Encoding != nil && b.Hash != nil { - return b.Encoding, b.Hash, nil + if !b.dirty && b.Encoding != nil && b.hashDigest != nil { + return b.Encoding, b.hashDigest, nil } buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) @@ -50,9 +50,9 @@ func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { encoding = b.Encoding // no need to copy if buffer.Len() < 32 { - b.Hash = make([]byte, len(bufferBytes)) - copy(b.Hash, bufferBytes) - hash = b.Hash // no need to copy + b.hashDigest = make([]byte, len(bufferBytes)) + copy(b.hashDigest, bufferBytes) + hash = b.hashDigest // no need to copy return encoding, hash, nil } @@ -61,8 +61,8 @@ func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { if err != nil { return nil, nil, err } - b.Hash = hashArray[:] - hash = b.Hash // no need to copy + b.hashDigest = hashArray[:] + hash = b.hashDigest // no need to copy return encoding, hash, nil } diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go index a5c0b88d46..59202829a9 100644 --- a/internal/trie/node/leaf.go +++ b/internal/trie/node/leaf.go @@ -19,7 +19,7 @@ type Leaf struct { // Dirty is true when the branch differs // from the node stored in the database. dirty bool - Hash []byte + hashDigest []byte Encoding []byte encodingMu sync.RWMutex // generation is incremented on every trie Snapshot() call. diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go index 526ce8cb32..b203d7d9f0 100644 --- a/internal/trie/node/leaf_encode.go +++ b/internal/trie/node/leaf_encode.go @@ -21,7 +21,7 @@ func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { l.encodingMu.Lock() l.Encoding = enc l.encodingMu.Unlock() - l.Hash = hash + l.hashDigest = hash } // GetHash returns the hash of the leaf. @@ -29,7 +29,7 @@ func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { // the returned hash will modify the hash // of the branch. func (l *Leaf) GetHash() []byte { - return l.Hash + return l.hashDigest } // EncodeAndHash returns the encoding of the leaf and @@ -38,9 +38,9 @@ func (l *Leaf) GetHash() []byte { // is the encoding and not the hash of the encoding. 
func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { l.encodingMu.RLock() - if !l.IsDirty() && l.Encoding != nil && l.Hash != nil { + if !l.IsDirty() && l.Encoding != nil && l.hashDigest != nil { l.encodingMu.RUnlock() - return l.Encoding, l.Hash, nil + return l.Encoding, l.hashDigest, nil } l.encodingMu.RUnlock() @@ -64,9 +64,9 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { encoding = l.Encoding // no need to copy if len(bufferBytes) < 32 { - l.Hash = make([]byte, len(bufferBytes)) - copy(l.Hash, bufferBytes) - hash = l.Hash // no need to copy + l.hashDigest = make([]byte, len(bufferBytes)) + copy(l.hashDigest, bufferBytes) + hash = l.hashDigest // no need to copy return encoding, hash, nil } @@ -76,8 +76,8 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { return nil, nil, err } - l.Hash = hashArray[:] - hash = l.Hash // no need to copy + l.hashDigest = hashArray[:] + hash = l.hashDigest // no need to copy return encoding, hash, nil } diff --git a/lib/trie/database.go b/lib/trie/database.go index 362720c5ce..f80c2d7b65 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -273,7 +273,7 @@ func getFromDB(db chaindb.Database, parent Node, key []byte) ([]byte, error) { } // load child with potential value - enc, err := db.Get(p.Children[key[length]].(*node.Leaf).Hash) + enc, err := db.Get(p.Children[key[length]].GetHash()) if err != nil { return nil, fmt.Errorf("failed to find node in database: %w", err) } From 89bf885a02de78c1a60ba38569a8042a3514dd50 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 19:22:15 +0000 Subject: [PATCH 39/50] Trie string does not cache encoding in nodes --- lib/trie/print.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/lib/trie/print.go b/lib/trie/print.go index 4b3ffcc4da..e39c8069f6 100644 --- a/lib/trie/print.go +++ b/lib/trie/print.go @@ -32,14 +32,14 @@ func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { buffer.Reset() _ = c.Encode(buffer) - c.Encoding = buffer.Bytes() + encoding := buffer.Bytes() var bstr string - if len(c.Encoding) > 1024 { + if len(encoding) > 1024 { bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", - idx, c, common.MustBlake2bHash(c.Encoding), c.GetGeneration()) + idx, c, common.MustBlake2bHash(encoding), c.GetGeneration()) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.GetGeneration()) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), encoding, c.GetGeneration()) } pools.EncodingBuffers.Put(buffer) @@ -56,16 +56,14 @@ func (t *Trie) string(tree gotree.Tree, curr Node, idx int) { _ = c.Encode(buffer) - // TODO lock or use methods on leaf to set the encoding bytes. 
- // Right now this is only used for debugging so no need to lock - c.Encoding = buffer.Bytes() + encoding := buffer.Bytes() var bstr string - if len(c.Encoding) > 1024 { + if len(encoding) > 1024 { bstr = fmt.Sprintf("idx=%d %s hash=%x gen=%d", - idx, c.String(), common.MustBlake2bHash(c.Encoding), c.GetGeneration()) + idx, c.String(), common.MustBlake2bHash(encoding), c.GetGeneration()) } else { - bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), c.Encoding, c.GetGeneration()) + bstr = fmt.Sprintf("idx=%d %s encode=%x gen=%d", idx, c.String(), encoding, c.GetGeneration()) } pools.EncodingBuffers.Put(buffer) From 23560c653c372bc98a7da1fe7ead02e89b191f10 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 13 Dec 2021 19:22:54 +0000 Subject: [PATCH 40/50] Unexport node's `Encoding` field --- internal/trie/node/branch.go | 2 +- internal/trie/node/branch_encode.go | 4 ++-- internal/trie/node/branch_encode_test.go | 4 ++-- internal/trie/node/copy.go | 8 ++++---- internal/trie/node/hash.go | 12 ++++++------ internal/trie/node/leaf.go | 2 +- internal/trie/node/leaf_encode.go | 20 ++++++++++---------- internal/trie/node/leaf_encode_test.go | 14 +++++++------- 8 files changed, 33 insertions(+), 33 deletions(-) diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index 09005d2723..ead394b801 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -21,7 +21,7 @@ type Branch struct { // from the node stored in the database. dirty bool hashDigest []byte - Encoding []byte + encoding []byte // generation is incremented on every trie Snapshot() call. // Nodes that are part of the trie are then gradually updated // to have a matching generation number as well, if they are diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 6520e234be..76700cc567 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -76,8 +76,8 @@ func (b *Branch) hash(digestBuffer io.Writer) (err error) { // Encode encodes a branch with the encoding specified at the top of this package // to the buffer given. 
func (b *Branch) Encode(buffer Buffer) (err error) { - if !b.dirty && b.Encoding != nil { - _, err = buffer.Write(b.Encoding) + if !b.dirty && b.encoding != nil { + _, err = buffer.Write(b.encoding) if err != nil { return fmt.Errorf("cannot write stored encoding to buffer: %w", err) } diff --git a/internal/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go index 60be03aa18..54504665bf 100644 --- a/internal/trie/node/branch_encode_test.go +++ b/internal/trie/node/branch_encode_test.go @@ -22,7 +22,7 @@ func Test_Branch_Encode(t *testing.T) { }{ "clean branch with encoding": { branch: &Branch{ - Encoding: []byte{1, 2, 3}, + encoding: []byte{1, 2, 3}, }, writes: []writeCall{ { // stored encoding @@ -32,7 +32,7 @@ func Test_Branch_Encode(t *testing.T) { }, "write error for clean branch with encoding": { branch: &Branch{ - Encoding: []byte{1, 2, 3}, + encoding: []byte{1, 2, 3}, }, writes: []writeCall{ { // stored encoding diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index ef5d599eb9..a90be5a680 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -14,7 +14,7 @@ func (b *Branch) Copy() Node { Value: nil, dirty: b.dirty, hashDigest: make([]byte, len(b.hashDigest)), - Encoding: make([]byte, len(b.Encoding)), + encoding: make([]byte, len(b.encoding)), generation: b.generation, } copy(cpy.Key, b.Key) @@ -26,7 +26,7 @@ func (b *Branch) Copy() Node { } copy(cpy.hashDigest, b.hashDigest) - copy(cpy.Encoding, b.Encoding) + copy(cpy.encoding, b.encoding) return cpy } @@ -43,12 +43,12 @@ func (l *Leaf) Copy() Node { Value: make([]byte, len(l.Value)), dirty: l.dirty, hashDigest: make([]byte, len(l.hashDigest)), - Encoding: make([]byte, len(l.Encoding)), + encoding: make([]byte, len(l.encoding)), generation: l.generation, } copy(cpy.Key, l.Key) copy(cpy.Value, l.Value) copy(cpy.hashDigest, l.hashDigest) - copy(cpy.Encoding, l.Encoding) + copy(cpy.encoding, l.encoding) return cpy } diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go index 5fefe928dd..ac8a584572 100644 --- a/internal/trie/node/hash.go +++ b/internal/trie/node/hash.go @@ -13,7 +13,7 @@ import ( // SetEncodingAndHash sets the encoding and hash slices // given to the branch. Note it does not copy them, so beware. func (b *Branch) SetEncodingAndHash(enc, hash []byte) { - b.Encoding = enc + b.encoding = enc b.hashDigest = hash } @@ -30,8 +30,8 @@ func (b *Branch) GetHash() []byte { // If the encoding is less than 32 bytes, the hash returned // is the encoding and not the hash of the encoding. 
func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { - if !b.dirty && b.Encoding != nil && b.hashDigest != nil { - return b.Encoding, b.hashDigest, nil + if !b.dirty && b.encoding != nil && b.hashDigest != nil { + return b.encoding, b.hashDigest, nil } buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) @@ -45,9 +45,9 @@ func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) { bufferBytes := buffer.Bytes() - b.Encoding = make([]byte, len(bufferBytes)) - copy(b.Encoding, bufferBytes) - encoding = b.Encoding // no need to copy + b.encoding = make([]byte, len(bufferBytes)) + copy(b.encoding, bufferBytes) + encoding = b.encoding // no need to copy if buffer.Len() < 32 { b.hashDigest = make([]byte, len(bufferBytes)) diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go index 59202829a9..3aab730cbb 100644 --- a/internal/trie/node/leaf.go +++ b/internal/trie/node/leaf.go @@ -20,7 +20,7 @@ type Leaf struct { // from the node stored in the database. dirty bool hashDigest []byte - Encoding []byte + encoding []byte encodingMu sync.RWMutex // generation is incremented on every trie Snapshot() call. // Nodes that are part of the trie are then gradually updated diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go index b203d7d9f0..291b45ae40 100644 --- a/internal/trie/node/leaf_encode.go +++ b/internal/trie/node/leaf_encode.go @@ -19,7 +19,7 @@ import ( // given to the branch. Note it does not copy them, so beware. func (l *Leaf) SetEncodingAndHash(enc, hash []byte) { l.encodingMu.Lock() - l.Encoding = enc + l.encoding = enc l.encodingMu.Unlock() l.hashDigest = hash } @@ -38,9 +38,9 @@ func (l *Leaf) GetHash() []byte { // is the encoding and not the hash of the encoding. func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { l.encodingMu.RLock() - if !l.IsDirty() && l.Encoding != nil && l.hashDigest != nil { + if !l.IsDirty() && l.encoding != nil && l.hashDigest != nil { l.encodingMu.RUnlock() - return l.Encoding, l.hashDigest, nil + return l.encoding, l.hashDigest, nil } l.encodingMu.RUnlock() @@ -58,10 +58,10 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { l.encodingMu.Lock() // TODO remove this copying since it defeats the purpose of `buffer` // and the sync.Pool. - l.Encoding = make([]byte, len(bufferBytes)) - copy(l.Encoding, bufferBytes) + l.encoding = make([]byte, len(bufferBytes)) + copy(l.encoding, bufferBytes) l.encodingMu.Unlock() - encoding = l.Encoding // no need to copy + encoding = l.encoding // no need to copy if len(bufferBytes) < 32 { l.hashDigest = make([]byte, len(bufferBytes)) @@ -87,8 +87,8 @@ func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { // NodeHeader | Extra partial key length | Partial Key | Value func (l *Leaf) Encode(buffer Buffer) (err error) { l.encodingMu.RLock() - if !l.dirty && l.Encoding != nil { - _, err = buffer.Write(l.Encoding) + if !l.dirty && l.encoding != nil { + _, err = buffer.Write(l.encoding) l.encodingMu.RUnlock() if err != nil { return fmt.Errorf("cannot write stored encoding to buffer: %w", err) @@ -122,8 +122,8 @@ func (l *Leaf) Encode(buffer Buffer) (err error) { // and the sync.Pool. 
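// ---------------------------------------------------------------------------
// Editor's illustrative sketch (assumption: not part of these patches or of
// the gossamer API). It restates the two rules the EncodeAndHash methods
// above rely on: a clean node with cached values returns them directly, and
// an encoding shorter than 32 bytes is used as the hash itself (the node is
// "inlined") instead of a Blake2b digest. exampleNode, encode and
// encodeAndHashSketch are hypothetical names; golang.org/x/crypto/blake2b
// stands in for the project's Blake2b helper.
package main

import (
	"fmt"

	"golang.org/x/crypto/blake2b"
)

// exampleNode mimics the cached fields a trie node keeps.
type exampleNode struct {
	dirty      bool
	encoding   []byte // cached encoding, nil until computed
	hashDigest []byte // cached hash, nil until computed
}

// encodeAndHashSketch returns the encoding and hash of the node,
// reusing the cache when the node is clean.
func encodeAndHashSketch(n *exampleNode, encode func() []byte) (encoding, hash []byte) {
	if !n.dirty && n.encoding != nil && n.hashDigest != nil {
		return n.encoding, n.hashDigest // cached, node unchanged since last time
	}

	encoding = encode()

	if len(encoding) < 32 {
		hash = encoding // inline rule: short encodings are their own "hash"
	} else {
		digest := blake2b.Sum256(encoding)
		hash = digest[:]
	}

	n.encoding, n.hashDigest = encoding, hash // cache for the next call
	return encoding, hash
}

func main() {
	n := &exampleNode{dirty: true}
	enc, hash := encodeAndHashSketch(n, func() []byte { return []byte{0x41, 0x09} })
	fmt.Printf("enc=%x hash=%x (identical: short encoding is inlined)\n", enc, hash)
}
// ---------------------------------------------------------------------------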
l.encodingMu.Lock() defer l.encodingMu.Unlock() - l.Encoding = make([]byte, buffer.Len()) - copy(l.Encoding, buffer.Bytes()) + l.encoding = make([]byte, buffer.Len()) + copy(l.encoding, buffer.Bytes()) return nil } diff --git a/internal/trie/node/leaf_encode_test.go b/internal/trie/node/leaf_encode_test.go index fdac0713c8..61eb78ad9b 100644 --- a/internal/trie/node/leaf_encode_test.go +++ b/internal/trie/node/leaf_encode_test.go @@ -26,7 +26,7 @@ func Test_Leaf_Encode(t *testing.T) { }{ "clean leaf with encoding": { leaf: &Leaf{ - Encoding: []byte{1, 2, 3}, + encoding: []byte{1, 2, 3}, }, writes: []writeCall{ { @@ -37,7 +37,7 @@ func Test_Leaf_Encode(t *testing.T) { }, "write error for clean leaf with encoding": { leaf: &Leaf{ - Encoding: []byte{1, 2, 3}, + encoding: []byte{1, 2, 3}, }, writes: []writeCall{ { @@ -153,7 +153,7 @@ func Test_Leaf_Encode(t *testing.T) { } else { require.NoError(t, err) } - assert.Equal(t, testCase.expectedEncoding, testCase.leaf.Encoding) + assert.Equal(t, testCase.expectedEncoding, testCase.leaf.encoding) }) } } @@ -204,7 +204,7 @@ func Test_Leaf_hash(t *testing.T) { }{ "small leaf buffer write error": { leaf: &Leaf{ - Encoding: []byte{1, 2, 3}, + encoding: []byte{1, 2, 3}, }, writeCall: true, write: writeCall{ @@ -217,7 +217,7 @@ func Test_Leaf_hash(t *testing.T) { }, "small leaf success": { leaf: &Leaf{ - Encoding: []byte{1, 2, 3}, + encoding: []byte{1, 2, 3}, }, writeCall: true, write: writeCall{ @@ -226,7 +226,7 @@ func Test_Leaf_hash(t *testing.T) { }, "leaf hash sum buffer write error": { leaf: &Leaf{ - Encoding: []byte{ + encoding: []byte{ 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, @@ -250,7 +250,7 @@ func Test_Leaf_hash(t *testing.T) { }, "leaf hash sum success": { leaf: &Leaf{ - Encoding: []byte{ + encoding: []byte{ 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, From 1837d85b892c03058de3aae22e5dc823ca21ec8e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 14 Dec 2021 12:32:06 +0000 Subject: [PATCH 41/50] Fix `NewBranch` to `NewLeaf` --- lib/trie/trie.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/trie/trie.go b/lib/trie/trie.go index 94a56053db..c8a33d7167 100644 --- a/lib/trie/trie.go +++ b/lib/trie/trie.go @@ -757,7 +757,7 @@ func handleDeletion(p *node.Branch, key []byte) Node { // if branch has no children, just a value, turn it into a leaf if bitmap == 0 && p.Value != nil { - n = node.NewBranch(key[:length], p.Value, true, 0) + n = node.NewLeaf(key[:length], p.Value, true, 0) } else if p.NumChildren() == 1 && p.Value == nil { // there is only 1 child and no value, combine the child branch with this branch // find index of child From 3129e796ad7d177783c12589a4305a9e22da88ab Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 14 Dec 2021 12:38:08 +0000 Subject: [PATCH 42/50] Add node type comments in tests --- internal/trie/node/decode_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index b5706d7b76..9998dde230 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -52,7 +52,7 @@ func Test_Decode(t *testing.T) { }, "leaf decoding error": { reader: bytes.NewReader([]byte{ - 65, // node type 1 and key length 1 + 65, // node type 1 (leaf) and key length 1 // missing key data byte }), errWrapped: ErrReadKeyData, @@ -62,7 +62,7 @@ func Test_Decode(t *testing.T) { reader: bytes.NewReader( append( []byte{ 
- 65, // node type 1 and key length 1 + 65, // node type 1 (leaf) and key length 1 9, // key data }, scaleEncodeBytes(t, 1, 2, 3)..., @@ -76,7 +76,7 @@ func Test_Decode(t *testing.T) { }, "branch decoding error": { reader: bytes.NewReader([]byte{ - 129, // node type 2 and key length 1 + 129, // node type 2 (branch without value) and key length 1 // missing key data byte }), errWrapped: ErrReadKeyData, @@ -85,7 +85,7 @@ func Test_Decode(t *testing.T) { "branch success": { reader: bytes.NewReader( []byte{ - 129, // node type 2 and key length 1 + 129, // node type 2 (branch without value) and key length 1 9, // key data 0, 0, // no children bitmap }, @@ -133,7 +133,7 @@ func Test_decodeBranch(t *testing.T) { reader: bytes.NewBuffer([]byte{ // missing key data byte }), - header: 129, // node type 2 and key length 1 + header: 129, // node type 2 (branch without value) and key length 1 errWrapped: ErrReadKeyData, errMessage: "cannot decode key: cannot read key data: EOF", }, @@ -142,7 +142,7 @@ func Test_decodeBranch(t *testing.T) { 9, // key data // missing children bitmap 2 bytes }), - header: 129, // node type 2 and key length 1 + header: 129, // node type 2 (branch without value) and key length 1 errWrapped: ErrReadChildrenBitmap, errMessage: "cannot read children bitmap: EOF", }, @@ -152,7 +152,7 @@ func Test_decodeBranch(t *testing.T) { 0, 4, // children bitmap // missing children scale encoded data }), - header: 129, // node type 2 and key length 1 + header: 129, // node type 2 (branch without value) and key length 1 errWrapped: ErrDecodeChildHash, errMessage: "cannot decode child hash: at index 10: EOF", }, @@ -166,7 +166,7 @@ func Test_decodeBranch(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash }), ), - header: 129, // node type 2 and key length 1 + header: 129, // node type 2 (branch without value) and key length 1 branch: &Branch{ Key: []byte{9}, Children: [16]Node{ @@ -187,7 +187,7 @@ func Test_decodeBranch(t *testing.T) { // missing encoded branch value }), ), - header: 193, // node type 3 and key length 1 + header: 193, // node type 3 (branch with value) and key length 1 errWrapped: ErrDecodeValue, errMessage: "cannot decode value: EOF", }, @@ -200,7 +200,7 @@ func Test_decodeBranch(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3, 4, 5), // child hash }), ), - header: 193, // node type 3 and key length 1 + header: 193, // node type 3 (branch with value) and key length 1 branch: &Branch{ Key: []byte{9}, Value: []byte{7, 8, 9}, @@ -252,7 +252,7 @@ func Test_decodeLeaf(t *testing.T) { reader: bytes.NewBuffer([]byte{ // missing key data byte }), - header: 65, // node type 1 and key length 1 + header: 65, // node type 1 (leaf) and key length 1 errWrapped: ErrReadKeyData, errMessage: "cannot decode key: cannot read key data: EOF", }, @@ -261,7 +261,7 @@ func Test_decodeLeaf(t *testing.T) { 9, // key data 255, 255, // bad value data }), - header: 65, // node type 1 and key length 1 + header: 65, // node type 1 (leaf) and key length 1 errWrapped: ErrDecodeValue, errMessage: "cannot decode value: could not decode invalid integer", }, @@ -270,7 +270,7 @@ func Test_decodeLeaf(t *testing.T) { 9, // key data // missing value data }), - header: 65, // node type 1 and key length 1 + header: 65, // node type 1 (leaf) and key length 1 leaf: &Leaf{ Key: []byte{9}, dirty: true, @@ -283,7 +283,7 @@ func Test_decodeLeaf(t *testing.T) { scaleEncodeBytes(t, 1, 2, 3, 4, 5), // value data }), ), - header: 65, // node type 1 and key length 1 + header: 65, // node type 1 (leaf) and key length 
1 leaf: &Leaf{ Key: []byte{9}, Value: []byte{1, 2, 3, 4, 5}, From 7503f21e381245378d0c3250b84a9b2d2a1ceab0 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 14 Dec 2021 12:39:20 +0000 Subject: [PATCH 43/50] Use `node.Type` for Type constants --- internal/trie/node/decode.go | 6 +++--- internal/trie/node/types.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 1a9850e10b..5bab716a92 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -36,7 +36,7 @@ func Decode(reader io.Reader) (n Node, err error) { } header := oneByteBuf[0] - nodeType := header >> 6 + nodeType := Type(header >> 6) switch nodeType { case LeafType: n, err = decodeLeaf(reader, header) @@ -61,7 +61,7 @@ func Decode(reader io.Reader) (n Node, err error) { // children are known to be with an empty leaf. The children nodes hashes are then used to // find other values using the persistent database. func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { - nodeType := header >> 6 + nodeType := Type(header >> 6) if nodeType != BranchType && nodeType != BranchWithValueType { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotABranch, nodeType) } @@ -115,7 +115,7 @@ func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { // decodeLeaf reads and decodes from a reader with the encoding specified in lib/trie/node/encode_doc.go. func decodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { - nodeType := header >> 6 + nodeType := Type(header >> 6) if nodeType != LeafType { return nil, fmt.Errorf("%w: %d", ErrNodeTypeIsNotALeaf, nodeType) } diff --git a/internal/trie/node/types.go b/internal/trie/node/types.go index a912955b3e..5f0ef8191b 100644 --- a/internal/trie/node/types.go +++ b/internal/trie/node/types.go @@ -7,7 +7,7 @@ package node type Type byte const ( - _ = iota + _ Type = iota // LeafType type is 1 LeafType // BranchType type is 2 From e17c08e42cc377144e4d6a23b13520b8f1c7f3fc Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Tue, 14 Dec 2021 20:31:14 +0000 Subject: [PATCH 44/50] `keyLenOffset` constant 0x3f --- internal/trie/node/decode.go | 4 ++-- internal/trie/node/header.go | 12 ++++++++---- internal/trie/node/key.go | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 5bab716a92..007cae95c3 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -68,7 +68,7 @@ func decodeBranch(reader io.Reader, header byte) (branch *Branch, err error) { branch = new(Branch) - keyLen := header & 0x3f + keyLen := header & keyLenOffset branch.Key, err = decodeKey(reader, keyLen) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) @@ -124,7 +124,7 @@ func decodeLeaf(reader io.Reader, header byte) (leaf *Leaf, err error) { dirty: true, } - keyLen := header & 0x3f + keyLen := header & keyLenOffset leaf.Key, err = decodeKey(reader, keyLen) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) diff --git a/internal/trie/node/header.go b/internal/trie/node/header.go index b5d28a6415..424d21e307 100644 --- a/internal/trie/node/header.go +++ b/internal/trie/node/header.go @@ -7,6 +7,10 @@ import ( "io" ) +const ( + keyLenOffset = 0x3f +) + // encodeHeader creates the encoded header for the branch. 
func (b *Branch) encodeHeader(writer io.Writer) (err error) { var header byte @@ -16,8 +20,8 @@ func (b *Branch) encodeHeader(writer io.Writer) (err error) { header = 3 << 6 } - if len(b.Key) >= 63 { - header = header | 0x3f + if len(b.Key) >= keyLenOffset { + header = header | keyLenOffset _, err = writer.Write([]byte{header}) if err != nil { return err @@ -43,12 +47,12 @@ func (l *Leaf) encodeHeader(writer io.Writer) (err error) { var header byte = 1 << 6 if len(l.Key) < 63 { - header = header | byte(len(l.Key)) + header |= byte(len(l.Key)) _, err = writer.Write([]byte{header}) return err } - header = header | 0x3f + header |= keyLenOffset _, err = writer.Write([]byte{header}) if err != nil { return err diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go index c1dfd56865..c42438493b 100644 --- a/internal/trie/node/key.go +++ b/internal/trie/node/key.go @@ -81,7 +81,7 @@ func encodeKeyLength(keyLength int, writer io.Writer) (err error) { func decodeKey(reader io.Reader, keyLength byte) (b []byte, err error) { publicKeyLength := int(keyLength) - if keyLength == 0x3f { + if keyLength == keyLenOffset { // partial key longer than 63, read next bytes for rest of pk len buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) defer pools.SingleByteBuffers.Put(buffer) From 865aae6d143ec3137c92d30a7ce7235b96bb41bb Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 15 Dec 2021 19:02:38 +0000 Subject: [PATCH 45/50] Fix comment (@noot suggestion) --- internal/trie/node/branch_encode.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/trie/node/branch_encode.go b/internal/trie/node/branch_encode.go index 76700cc567..badd3556f0 100644 --- a/internal/trie/node/branch_encode.go +++ b/internal/trie/node/branch_encode.go @@ -46,7 +46,7 @@ func (b *Branch) hash(digestBuffer io.Writer) (err error) { return fmt.Errorf("cannot encode leaf: %w", err) } - // if length of encoded leaf is less than 32 bytes, do not hash + // if length of encoded branch is less than 32 bytes, do not hash if encodingBuffer.Len() < 32 { _, err = digestBuffer.Write(encodingBuffer.Bytes()) if err != nil { From 8c3178891f3f0aeeb842482f8c9213619e04b232 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Wed, 15 Dec 2021 20:01:45 +0000 Subject: [PATCH 46/50] Updated generation comment --- internal/trie/node/branch.go | 6 +++--- internal/trie/node/leaf.go | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index ead394b801..58473023a9 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -23,9 +23,9 @@ type Branch struct { hashDigest []byte encoding []byte // generation is incremented on every trie Snapshot() call. - // Nodes that are part of the trie are then gradually updated - // to have a matching generation number as well, if they are - // still relevant. + // Each node also contain a certain generation number, + // which is updated to match the trie generation once they are + // inserted, moved or iterated over. generation uint64 sync.RWMutex } diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go index 3aab730cbb..77c884f397 100644 --- a/internal/trie/node/leaf.go +++ b/internal/trie/node/leaf.go @@ -23,9 +23,9 @@ type Leaf struct { encoding []byte encodingMu sync.RWMutex // generation is incremented on every trie Snapshot() call. 
- // Nodes that are part of the trie are then gradually updated - // to have a matching generation number as well, if they are - // still relevant. + // Each node also contain a certain generation number, + // which is updated to match the trie generation once they are + // inserted, moved or iterated over. generation uint64 sync.RWMutex } From 831213c415835a1a1a929f2e12d0f55ea9137fa0 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 16 Dec 2021 12:47:02 +0000 Subject: [PATCH 47/50] fix copy functions to copy nil slices correctly --- internal/trie/node/copy.go | 51 +++++++++---- internal/trie/node/copy_test.go | 124 ++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 14 deletions(-) create mode 100644 internal/trie/node/copy_test.go diff --git a/internal/trie/node/copy.go b/internal/trie/node/copy.go index a90be5a680..1a59f24d21 100644 --- a/internal/trie/node/copy.go +++ b/internal/trie/node/copy.go @@ -9,24 +9,33 @@ func (b *Branch) Copy() Node { defer b.RUnlock() cpy := &Branch{ - Key: make([]byte, len(b.Key)), Children: b.Children, // copy interface pointers - Value: nil, dirty: b.dirty, - hashDigest: make([]byte, len(b.hashDigest)), - encoding: make([]byte, len(b.encoding)), generation: b.generation, } copy(cpy.Key, b.Key) + if b.Key != nil { + cpy.Key = make([]byte, len(b.Key)) + copy(cpy.Key, b.Key) + } + // nil and []byte{} are encoded differently, watch out! if b.Value != nil { cpy.Value = make([]byte, len(b.Value)) copy(cpy.Value, b.Value) } - copy(cpy.hashDigest, b.hashDigest) - copy(cpy.encoding, b.encoding) + if b.hashDigest != nil { + cpy.hashDigest = make([]byte, len(b.hashDigest)) + copy(cpy.hashDigest, b.hashDigest) + } + + if b.encoding != nil { + cpy.encoding = make([]byte, len(b.encoding)) + copy(cpy.encoding, b.encoding) + } + return cpy } @@ -39,16 +48,30 @@ func (l *Leaf) Copy() Node { defer l.encodingMu.RUnlock() cpy := &Leaf{ - Key: make([]byte, len(l.Key)), - Value: make([]byte, len(l.Value)), dirty: l.dirty, - hashDigest: make([]byte, len(l.hashDigest)), - encoding: make([]byte, len(l.encoding)), generation: l.generation, } - copy(cpy.Key, l.Key) - copy(cpy.Value, l.Value) - copy(cpy.hashDigest, l.hashDigest) - copy(cpy.encoding, l.encoding) + + if l.Key != nil { + cpy.Key = make([]byte, len(l.Key)) + copy(cpy.Key, l.Key) + } + + // nil and []byte{} are encoded differently, watch out! 
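// ---------------------------------------------------------------------------
// Editor's illustrative sketch (assumption: not part of this patch). It shows
// why the Copy methods in this patch only allocate when the source slice is
// non-nil, as the surrounding hunk does: a naive make([]byte, len(src)) turns
// a nil Value into an empty non-nil slice, and since nil and []byte{} encode
// differently, the copy would no longer encode (and therefore hash) the same
// as the original node.
package main

import "fmt"

func main() {
	var src []byte                  // nil slice, e.g. a node with no value
	naive := make([]byte, len(src)) // zero length, but NOT nil anymore

	fmt.Println(src == nil)   // true
	fmt.Println(naive == nil) // false: the naive copy changed nil into []byte{}
}
// ---------------------------------------------------------------------------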
+ if l.Value != nil { + cpy.Value = make([]byte, len(l.Value)) + copy(cpy.Value, l.Value) + } + + if l.hashDigest != nil { + cpy.hashDigest = make([]byte, len(l.hashDigest)) + copy(cpy.hashDigest, l.hashDigest) + } + + if l.encoding != nil { + cpy.encoding = make([]byte, len(l.encoding)) + copy(cpy.encoding, l.encoding) + } + return cpy } diff --git a/internal/trie/node/copy_test.go b/internal/trie/node/copy_test.go new file mode 100644 index 0000000000..75de5f6284 --- /dev/null +++ b/internal/trie/node/copy_test.go @@ -0,0 +1,124 @@ +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testForSliceModif(t *testing.T, original, copied []byte) { + t.Helper() + require.Equal(t, len(original), len(copied)) + if len(copied) == 0 { + // cannot test for modification + return + } + original[0]++ + assert.NotEqual(t, copied, original) +} + +func Test_Branch_Copy(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + expectedBranch *Branch + }{ + "empty branch": { + branch: &Branch{}, + expectedBranch: &Branch{}, + }, + "non empty branch": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + Children: [16]Node{ + nil, nil, &Leaf{Key: []byte{9}}, + }, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + expectedBranch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + Children: [16]Node{ + nil, nil, &Leaf{Key: []byte{9}}, + }, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodeCopy := testCase.branch.Copy() + + branchCopy, ok := nodeCopy.(*Branch) + require.True(t, ok) + + assert.Equal(t, testCase.expectedBranch, branchCopy) + testForSliceModif(t, testCase.branch.Key, branchCopy.Key) + testForSliceModif(t, testCase.branch.Value, branchCopy.Value) + testForSliceModif(t, testCase.branch.hashDigest, branchCopy.hashDigest) + testForSliceModif(t, testCase.branch.encoding, branchCopy.encoding) + + testCase.branch.Children[15] = &Leaf{Key: []byte("modified")} + assert.NotEqual(t, branchCopy.Children, testCase.branch.Children) + }) + } +} + +func Test_Leaf_Copy(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + expectedLeaf *Leaf + }{ + "empty leaf": { + leaf: &Leaf{}, + expectedLeaf: &Leaf{}, + }, + "non empty leaf": { + leaf: &Leaf{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + expectedLeaf: &Leaf{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + hashDigest: []byte{5}, + encoding: []byte{6}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + nodeCopy := testCase.leaf.Copy() + + leafCopy, ok := nodeCopy.(*Leaf) + require.True(t, ok) + + assert.Equal(t, testCase.expectedLeaf, leafCopy) + testForSliceModif(t, testCase.leaf.Key, leafCopy.Key) + testForSliceModif(t, testCase.leaf.Value, leafCopy.Value) + testForSliceModif(t, testCase.leaf.hashDigest, leafCopy.hashDigest) + testForSliceModif(t, testCase.leaf.encoding, leafCopy.encoding) + }) + } +} From 8a315af2bac510854af7bc3ba1a04592876b25b0 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 16 Dec 2021 12:56:45 +0000 Subject: [PATCH 48/50] Move leaf hash methods to hash.go --- internal/trie/node/hash.go | 67 ++++++++++++++++++++++++++++++ internal/trie/node/leaf_encode.go 
| 68 -------------------------------------
 2 files changed, 67 insertions(+), 68 deletions(-)

diff --git a/internal/trie/node/hash.go b/internal/trie/node/hash.go
index ac8a584572..315ad9b738 100644
--- a/internal/trie/node/hash.go
+++ b/internal/trie/node/hash.go
@@ -66,3 +66,70 @@ func (b *Branch) EncodeAndHash() (encoding, hash []byte, err error) {
 
 	return encoding, hash, nil
 }
+
+// SetEncodingAndHash sets the encoding and hash slices
+// given to the leaf. Note it does not copy them, so beware.
+func (l *Leaf) SetEncodingAndHash(enc, hash []byte) {
+	l.encodingMu.Lock()
+	l.encoding = enc
+	l.encodingMu.Unlock()
+	l.hashDigest = hash
+}
+
+// GetHash returns the hash of the leaf.
+// Note it does not copy it, so modifying
+// the returned hash will modify the hash
+// of the leaf.
+func (l *Leaf) GetHash() []byte {
+	return l.hashDigest
+}
+
+// EncodeAndHash returns the encoding of the leaf and
+// the blake2b hash digest of the encoding of the leaf.
+// If the encoding is less than 32 bytes, the hash returned
+// is the encoding and not the hash of the encoding.
+func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) {
+	l.encodingMu.RLock()
+	if !l.IsDirty() && l.encoding != nil && l.hashDigest != nil {
+		l.encodingMu.RUnlock()
+		return l.encoding, l.hashDigest, nil
+	}
+	l.encodingMu.RUnlock()
+
+	buffer := pools.EncodingBuffers.Get().(*bytes.Buffer)
+	buffer.Reset()
+	defer pools.EncodingBuffers.Put(buffer)
+
+	err = l.Encode(buffer)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	bufferBytes := buffer.Bytes()
+
+	l.encodingMu.Lock()
+	// TODO remove this copying since it defeats the purpose of `buffer`
+	// and the sync.Pool.
+	l.encoding = make([]byte, len(bufferBytes))
+	copy(l.encoding, bufferBytes)
+	l.encodingMu.Unlock()
+	encoding = l.encoding // no need to copy
+
+	if len(bufferBytes) < 32 {
+		l.hashDigest = make([]byte, len(bufferBytes))
+		copy(l.hashDigest, bufferBytes)
+		hash = l.hashDigest // no need to copy
+		return encoding, hash, nil
+	}
+
+	// Note: using the sync.Pool's buffer is useful here.
+	hashArray, err := common.Blake2bHash(buffer.Bytes())
+	if err != nil {
+		return nil, nil, err
+	}
+
+	l.hashDigest = hashArray[:]
+	hash = l.hashDigest // no need to copy
+
+	return encoding, hash, nil
+}
diff --git a/internal/trie/node/leaf_encode.go b/internal/trie/node/leaf_encode.go
index 291b45ae40..f18bbf0d28 100644
--- a/internal/trie/node/leaf_encode.go
+++ b/internal/trie/node/leaf_encode.go
@@ -11,77 +11,9 @@ import (
 
 	"github.com/ChainSafe/gossamer/internal/trie/codec"
 	"github.com/ChainSafe/gossamer/internal/trie/pools"
-	"github.com/ChainSafe/gossamer/lib/common"
 	"github.com/ChainSafe/gossamer/pkg/scale"
 )
 
-// SetEncodingAndHash sets the encoding and hash slices
-// given to the branch. Note it does not copy them, so beware.
-func (l *Leaf) SetEncodingAndHash(enc, hash []byte) {
-	l.encodingMu.Lock()
-	l.encoding = enc
-	l.encodingMu.Unlock()
-	l.hashDigest = hash
-}
-
-// GetHash returns the hash of the leaf.
-// Note it does not copy it, so modifying
-// the returned hash will modify the hash
-// of the branch.
-func (l *Leaf) GetHash() []byte {
-	return l.hashDigest
-}
-
-// EncodeAndHash returns the encoding of the leaf and
-// the blake2b hash digest of the encoding of the leaf.
-// If the encoding is less than 32 bytes, the hash returned
-// is the encoding and not the hash of the encoding.
-func (l *Leaf) EncodeAndHash() (encoding, hash []byte, err error) { - l.encodingMu.RLock() - if !l.IsDirty() && l.encoding != nil && l.hashDigest != nil { - l.encodingMu.RUnlock() - return l.encoding, l.hashDigest, nil - } - l.encodingMu.RUnlock() - - buffer := pools.EncodingBuffers.Get().(*bytes.Buffer) - buffer.Reset() - defer pools.EncodingBuffers.Put(buffer) - - err = l.Encode(buffer) - if err != nil { - return nil, nil, err - } - - bufferBytes := buffer.Bytes() - - l.encodingMu.Lock() - // TODO remove this copying since it defeats the purpose of `buffer` - // and the sync.Pool. - l.encoding = make([]byte, len(bufferBytes)) - copy(l.encoding, bufferBytes) - l.encodingMu.Unlock() - encoding = l.encoding // no need to copy - - if len(bufferBytes) < 32 { - l.hashDigest = make([]byte, len(bufferBytes)) - copy(l.hashDigest, bufferBytes) - hash = l.hashDigest // no need to copy - return encoding, hash, nil - } - - // Note: using the sync.Pool's buffer is useful here. - hashArray, err := common.Blake2bHash(buffer.Bytes()) - if err != nil { - return nil, nil, err - } - - l.hashDigest = hashArray[:] - hash = l.hashDigest // no need to copy - - return encoding, hash, nil -} - // Encode encodes a leaf to the buffer given. // The encoding has the following format: // NodeHeader | Extra partial key length | Partial Key | Value From 9c73b4821c80bc645bf5d2480ae0f932501a7d4e Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 16 Dec 2021 13:17:48 +0000 Subject: [PATCH 49/50] fix naming in `decodeKey` function --- internal/trie/node/key.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go index c42438493b..3478ef3aa7 100644 --- a/internal/trie/node/key.go +++ b/internal/trie/node/key.go @@ -78,10 +78,10 @@ func encodeKeyLength(keyLength int, writer io.Writer) (err error) { } // decodeKey decodes a key from a reader. 
-func decodeKey(reader io.Reader, keyLength byte) (b []byte, err error) { - publicKeyLength := int(keyLength) +func decodeKey(reader io.Reader, keyLengthByte byte) (b []byte, err error) { + keyLength := int(keyLengthByte) - if keyLength == keyLenOffset { + if keyLengthByte == keyLenOffset { // partial key longer than 63, read next bytes for rest of pk len buffer := pools.SingleByteBuffers.Get().(*bytes.Buffer) defer pools.SingleByteBuffers.Put(buffer) @@ -93,24 +93,24 @@ func decodeKey(reader io.Reader, keyLength byte) (b []byte, err error) { } nextKeyLen := oneByteBuf[0] - publicKeyLength += int(nextKeyLen) + keyLength += int(nextKeyLen) if nextKeyLen < 0xff { break } - if publicKeyLength >= int(maxPartialKeySize) { + if keyLength >= int(maxPartialKeySize) { return nil, fmt.Errorf("%w: %d", - ErrPartialKeyTooBig, publicKeyLength) + ErrPartialKeyTooBig, keyLength) } } } - if publicKeyLength == 0 { + if keyLength == 0 { return []byte{}, nil } - key := make([]byte, publicKeyLength/2+publicKeyLength%2) + key := make([]byte, keyLength/2+keyLength%2) n, err := reader.Read(key) if err != nil { return nil, fmt.Errorf("%w: %s", ErrReadKeyData, err) @@ -119,5 +119,5 @@ func decodeKey(reader io.Reader, keyLength byte) (b []byte, err error) { ErrReadKeyData, n, len(key)) } - return codec.KeyLEToNibbles(key)[publicKeyLength%2:], nil + return codec.KeyLEToNibbles(key)[keyLength%2:], nil } From 61baed1deef383ea811ff0bef062347d21412bd4 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Thu, 16 Dec 2021 13:53:04 +0000 Subject: [PATCH 50/50] 95.7% test coverage for `internal/trie/node` --- internal/trie/node/branch.go | 4 +- internal/trie/node/branch_encode_test.go | 128 ++++++++++++ internal/trie/node/branch_test.go | 95 +++++++++ internal/trie/node/copy_test.go | 3 + internal/trie/node/dirty_test.go | 150 +++++++++++++ internal/trie/node/generation_test.go | 50 +++++ internal/trie/node/hash_test.go | 254 +++++++++++++++++++++++ internal/trie/node/key_test.go | 163 +++++++++++++-- internal/trie/node/leaf.go | 4 +- internal/trie/node/leaf_test.go | 77 +++++++ internal/trie/node/reader_mock_test.go | 49 +++++ internal/trie/node/value_test.go | 30 +++ 12 files changed, 980 insertions(+), 27 deletions(-) create mode 100644 internal/trie/node/branch_test.go create mode 100644 internal/trie/node/dirty_test.go create mode 100644 internal/trie/node/generation_test.go create mode 100644 internal/trie/node/hash_test.go create mode 100644 internal/trie/node/leaf_test.go create mode 100644 internal/trie/node/reader_mock_test.go create mode 100644 internal/trie/node/value_test.go diff --git a/internal/trie/node/branch.go b/internal/trie/node/branch.go index 58473023a9..7f3422a6f1 100644 --- a/internal/trie/node/branch.go +++ b/internal/trie/node/branch.go @@ -42,9 +42,9 @@ func NewBranch(key, value []byte, dirty bool, generation uint64) *Branch { func (b *Branch) String() string { if len(b.Value) > 1024 { - return fmt.Sprintf("key=%x childrenBitmap=%16b value (hashed)=%x dirty=%v", + return fmt.Sprintf("branch key=0x%x childrenBitmap=%b value (hashed)=0x%x dirty=%t", b.Key, b.ChildrenBitmap(), common.MustBlake2bHash(b.Value), b.dirty) } - return fmt.Sprintf("key=%x childrenBitmap=%16b value=%v dirty=%v", + return fmt.Sprintf("branch key=0x%x childrenBitmap=%b value=0x%x dirty=%t", b.Key, b.ChildrenBitmap(), b.Value, b.dirty) } diff --git a/internal/trie/node/branch_encode_test.go b/internal/trie/node/branch_encode_test.go index 54504665bf..9c1fc50703 100644 --- a/internal/trie/node/branch_encode_test.go +++ 
b/internal/trie/node/branch_encode_test.go @@ -11,6 +11,134 @@ import ( "github.com/stretchr/testify/require" ) +func Test_Branch_ScaleEncodeHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + encoding []byte + wrappedErr error + errMessage string + }{ + "empty branch": { + branch: &Branch{}, + encoding: []byte{0xc, 0x80, 0x0, 0x0}, + }, + "non empty branch": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + Children: [16]Node{ + nil, nil, &Leaf{Key: []byte{9}}, + }, + }, + encoding: []byte{0x2c, 0xc2, 0x12, 0x4, 0x0, 0x8, 0x3, 0x4, 0xc, 0x41, 0x9, 0x0}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, err := testCase.branch.ScaleEncodeHash() + + if testCase.wrappedErr != nil { + assert.ErrorIs(t, err, testCase.wrappedErr) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + assert.Equal(t, testCase.encoding, encoding) + }) + } +} + +func Test_Branch_hash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + write writeCall + errWrapped error + errMessage string + }{ + "empty branch": { + branch: &Branch{}, + write: writeCall{ + written: []byte{128, 0, 0}, + }, + }, + "less than 32 bytes encoding": { + branch: &Branch{ + Key: []byte{1, 2}, + }, + write: writeCall{ + written: []byte{130, 18, 0, 0}, + }, + }, + "less than 32 bytes encoding write error": { + branch: &Branch{ + Key: []byte{1, 2}, + }, + write: writeCall{ + written: []byte{130, 18, 0, 0}, + err: errTest, + }, + errWrapped: errTest, + errMessage: "cannot write encoded branch to buffer: test error", + }, + "more than 32 bytes encoding": { + branch: &Branch{ + Key: repeatBytes(100, 1), + }, + write: writeCall{ + written: []byte{ + 70, 102, 188, 24, 31, 68, 86, 114, + 95, 156, 225, 138, 175, 254, 176, 251, + 81, 84, 193, 40, 11, 234, 142, 233, + 69, 250, 158, 86, 72, 228, 66, 46}, + }, + }, + "more than 32 bytes encoding write error": { + branch: &Branch{ + Key: repeatBytes(100, 1), + }, + write: writeCall{ + written: []byte{ + 70, 102, 188, 24, 31, 68, 86, 114, + 95, 156, 225, 138, 175, 254, 176, 251, + 81, 84, 193, 40, 11, 234, 142, 233, + 69, 250, 158, 86, 72, 228, 66, 46}, + err: errTest, + }, + errWrapped: errTest, + errMessage: "cannot write hash sum of branch to buffer: test error", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + digestBuffer := NewMockWriter(ctrl) + digestBuffer.EXPECT().Write(testCase.write.written). 
+ Return(testCase.write.n, testCase.write.err) + + err := testCase.branch.hash(digestBuffer) + + if testCase.errWrapped != nil { + assert.ErrorIs(t, err, testCase.errWrapped) + assert.EqualError(t, err, testCase.errMessage) + } else { + require.NoError(t, err) + } + }) + } +} + func Test_Branch_Encode(t *testing.T) { t.Parallel() diff --git a/internal/trie/node/branch_test.go b/internal/trie/node/branch_test.go new file mode 100644 index 0000000000..a7d4591c32 --- /dev/null +++ b/internal/trie/node/branch_test.go @@ -0,0 +1,95 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewBranch(t *testing.T) { + t.Parallel() + + key := []byte{1, 2} + value := []byte{3, 4} + const dirty = true + const generation = 9 + + branch := NewBranch(key, value, dirty, generation) + + expectedBranch := &Branch{ + Key: key, + Value: value, + dirty: dirty, + generation: generation, + } + assert.Equal(t, expectedBranch, branch) + + // Check modifying passed slice modifies branch slices + key[0] = 11 + value[0] = 13 + assert.Equal(t, expectedBranch, branch) +} + +func Test_Branch_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + s string + }{ + "empty branch": { + branch: &Branch{}, + s: "branch key=0x childrenBitmap=0 value=0x dirty=false", + }, + "branch with value smaller than 1024": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + Children: [16]Node{ + nil, nil, nil, + &Leaf{}, + nil, nil, nil, + &Branch{}, + nil, nil, nil, + &Leaf{}, + nil, nil, nil, nil, + }, + }, + s: "branch key=0x0102 childrenBitmap=100010001000 value=0x0304 dirty=true", + }, + "branch with value higher than 1024": { + branch: &Branch{ + Key: []byte{1, 2}, + Value: make([]byte, 1025), + dirty: true, + Children: [16]Node{ + nil, nil, nil, + &Leaf{}, + nil, nil, nil, + &Branch{}, + nil, nil, nil, + &Leaf{}, + nil, nil, nil, nil, + }, + }, + s: "branch key=0x0102 childrenBitmap=100010001000 " + + "value (hashed)=0x307861663233363133353361303538646238383034626337353735323831663131663735313265326331346336373032393864306232336630396538386266333066 " + //nolint:lll + "dirty=true", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := testCase.branch.String() + + assert.Equal(t, testCase.s, s) + }) + } +} diff --git a/internal/trie/node/copy_test.go b/internal/trie/node/copy_test.go index 75de5f6284..bff0f409c2 100644 --- a/internal/trie/node/copy_test.go +++ b/internal/trie/node/copy_test.go @@ -1,3 +1,6 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package node import ( diff --git a/internal/trie/node/dirty_test.go b/internal/trie/node/dirty_test.go new file mode 100644 index 0000000000..ebe9c02fa1 --- /dev/null +++ b/internal/trie/node/dirty_test.go @@ -0,0 +1,150 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_IsDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + dirty bool + }{ + "not dirty": { + branch: &Branch{}, + }, + "dirty": { + branch: &Branch{ + dirty: true, + }, + dirty: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + dirty := 
testCase.branch.IsDirty() + + assert.Equal(t, testCase.dirty, dirty) + }) + } +} + +func Test_Branch_SetDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + dirty bool + expected *Branch + }{ + "not dirty to not dirty": { + branch: &Branch{}, + expected: &Branch{}, + }, + "not dirty to dirty": { + branch: &Branch{}, + dirty: true, + expected: &Branch{dirty: true}, + }, + "dirty to not dirty": { + branch: &Branch{dirty: true}, + expected: &Branch{}, + }, + "dirty to dirty": { + branch: &Branch{dirty: true}, + dirty: true, + expected: &Branch{dirty: true}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.branch.SetDirty(testCase.dirty) + + assert.Equal(t, testCase.expected, testCase.branch) + }) + } +} + +func Test_Leaf_IsDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + dirty bool + }{ + "not dirty": { + leaf: &Leaf{}, + }, + "dirty": { + leaf: &Leaf{ + dirty: true, + }, + dirty: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + dirty := testCase.leaf.IsDirty() + + assert.Equal(t, testCase.dirty, dirty) + }) + } +} + +func Test_Leaf_SetDirty(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + dirty bool + expected *Leaf + }{ + "not dirty to not dirty": { + leaf: &Leaf{}, + expected: &Leaf{}, + }, + "not dirty to dirty": { + leaf: &Leaf{}, + dirty: true, + expected: &Leaf{dirty: true}, + }, + "dirty to not dirty": { + leaf: &Leaf{dirty: true}, + expected: &Leaf{}, + }, + "dirty to dirty": { + leaf: &Leaf{dirty: true}, + dirty: true, + expected: &Leaf{dirty: true}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + testCase.leaf.SetDirty(testCase.dirty) + + assert.Equal(t, testCase.expected, testCase.leaf) + }) + } +} diff --git a/internal/trie/node/generation_test.go b/internal/trie/node/generation_test.go new file mode 100644 index 0000000000..708d93058e --- /dev/null +++ b/internal/trie/node/generation_test.go @@ -0,0 +1,50 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_SetGeneration(t *testing.T) { + t.Parallel() + + branch := &Branch{ + generation: 1, + } + branch.SetGeneration(2) + assert.Equal(t, &Branch{generation: 2}, branch) +} + +func Test_Branch_GetGeneration(t *testing.T) { + t.Parallel() + + const generation uint64 = 1 + branch := &Branch{ + generation: generation, + } + assert.Equal(t, branch.GetGeneration(), generation) +} + +func Test_Leaf_SetGeneration(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + generation: 1, + } + leaf.SetGeneration(2) + assert.Equal(t, &Leaf{generation: 2}, leaf) +} + +func Test_Leaf_GetGeneration(t *testing.T) { + t.Parallel() + + const generation uint64 = 1 + leaf := &Leaf{ + generation: generation, + } + assert.Equal(t, leaf.GetGeneration(), generation) +} diff --git a/internal/trie/node/hash_test.go b/internal/trie/node/hash_test.go new file mode 100644 index 0000000000..26693cd76b --- /dev/null +++ b/internal/trie/node/hash_test.go @@ -0,0 +1,254 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func 
Test_Branch_SetEncodingAndHash(t *testing.T) { + t.Parallel() + + branch := &Branch{ + encoding: []byte{2}, + hashDigest: []byte{3}, + } + branch.SetEncodingAndHash([]byte{4}, []byte{5}) + + expectedBranch := &Branch{ + encoding: []byte{4}, + hashDigest: []byte{5}, + } + assert.Equal(t, expectedBranch, branch) +} + +func Test_Branch_GetHash(t *testing.T) { + t.Parallel() + + branch := &Branch{ + hashDigest: []byte{3}, + } + hash := branch.GetHash() + + expectedHash := []byte{3} + assert.Equal(t, expectedHash, hash) +} + +func Test_Branch_EncodeAndHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + branch *Branch + expectedBranch *Branch + encoding []byte + hash []byte + errWrapped error + errMessage string + }{ + "empty branch": { + branch: &Branch{}, + expectedBranch: &Branch{ + encoding: []byte{0x80, 0x0, 0x0}, + hashDigest: []byte{0x80, 0x0, 0x0}, + }, + encoding: []byte{0x80, 0x0, 0x0}, + hash: []byte{0x80, 0x0, 0x0}, + }, + "small branch encoding": { + branch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + }, + expectedBranch: &Branch{ + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hashDigest: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + "branch dirty with precomputed encoding and hash": { + branch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + dirty: true, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedBranch: &Branch{ + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hashDigest: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + encoding: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + hash: []byte{0xc1, 0x1, 0x0, 0x0, 0x4, 0x2}, + }, + "branch not dirty with precomputed encoding and hash": { + branch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + dirty: false, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedBranch: &Branch{ + Key: []byte{1}, + Value: []byte{2}, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + encoding: []byte{3}, + hash: []byte{4}, + }, + "large branch encoding": { + branch: &Branch{ + Key: repeatBytes(65, 7), + }, + expectedBranch: &Branch{ + encoding: []byte{0xbf, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0, 0x0}, //nolint:lll + hashDigest: []byte{0x6b, 0xd8, 0xcc, 0xac, 0x71, 0x77, 0x44, 0x17, 0xfe, 0xe0, 0xde, 0xda, 0xd5, 0x97, 0x6e, 0x69, 0xeb, 0xe9, 0xdd, 0x80, 0x1d, 0x4b, 0x51, 0xf1, 0x5b, 0xf3, 0x4a, 0x93, 0x27, 0x32, 0x2c, 0xb0}, //nolint:lll + }, + encoding: []byte{0xbf, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0, 0x0}, //nolint:lll + hash: []byte{0x6b, 0xd8, 0xcc, 0xac, 0x71, 0x77, 0x44, 0x17, 0xfe, 0xe0, 0xde, 0xda, 0xd5, 0x97, 0x6e, 0x69, 0xeb, 0xe9, 0xdd, 0x80, 0x1d, 0x4b, 0x51, 0xf1, 0x5b, 0xf3, 0x4a, 0x93, 0x27, 0x32, 0x2c, 0xb0}, //nolint:lll + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, hash, err := testCase.branch.EncodeAndHash() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.encoding, encoding) + assert.Equal(t, testCase.hash, hash) + }) + } +} + +func 
Test_Leaf_SetEncodingAndHash(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + encoding: []byte{2}, + hashDigest: []byte{3}, + } + leaf.SetEncodingAndHash([]byte{4}, []byte{5}) + + expectedLeaf := &Leaf{ + encoding: []byte{4}, + hashDigest: []byte{5}, + } + assert.Equal(t, expectedLeaf, leaf) +} + +func Test_Leaf_GetHash(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + hashDigest: []byte{3}, + } + hash := leaf.GetHash() + + expectedHash := []byte{3} + assert.Equal(t, expectedHash, hash) +} + +func Test_Leaf_EncodeAndHash(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + expectedLeaf *Leaf + encoding []byte + hash []byte + errWrapped error + errMessage string + }{ + "empty leaf": { + leaf: &Leaf{}, + expectedLeaf: &Leaf{ + encoding: []byte{0x40, 0x0}, + hashDigest: []byte{0x40, 0x0}, + }, + encoding: []byte{0x40, 0x0}, + hash: []byte{0x40, 0x0}, + }, + "small leaf encoding": { + leaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + }, + expectedLeaf: &Leaf{ + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hashDigest: []byte{0x41, 0x1, 0x4, 0x2}, + }, + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hash: []byte{0x41, 0x1, 0x4, 0x2}, + }, + "leaf dirty with precomputed encoding and hash": { + leaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + dirty: true, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedLeaf: &Leaf{ + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hashDigest: []byte{0x41, 0x1, 0x4, 0x2}, + }, + encoding: []byte{0x41, 0x1, 0x4, 0x2}, + hash: []byte{0x41, 0x1, 0x4, 0x2}, + }, + "leaf not dirty with precomputed encoding and hash": { + leaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + dirty: false, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + expectedLeaf: &Leaf{ + Key: []byte{1}, + Value: []byte{2}, + encoding: []byte{3}, + hashDigest: []byte{4}, + }, + encoding: []byte{3}, + hash: []byte{4}, + }, + "large leaf encoding": { + leaf: &Leaf{ + Key: repeatBytes(65, 7), + }, + expectedLeaf: &Leaf{ + encoding: []byte{0x7f, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0}, //nolint:lll + hashDigest: []byte{0xfb, 0xae, 0x31, 0x4b, 0xef, 0x31, 0x9, 0xc7, 0x62, 0x99, 0x9d, 0x40, 0x9b, 0xd4, 0xdc, 0x64, 0xe7, 0x39, 0x46, 0x8b, 0xd3, 0xaf, 0xe8, 0x63, 0x9d, 0xf9, 0x41, 0x40, 0x76, 0x40, 0x10, 0xa3}, //nolint:lll + }, + encoding: []byte{0x7f, 0x2, 0x7, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x77, 0x0}, //nolint:lll + hash: []byte{0xfb, 0xae, 0x31, 0x4b, 0xef, 0x31, 0x9, 0xc7, 0x62, 0x99, 0x9d, 0x40, 0x9b, 0xd4, 0xdc, 0x64, 0xe7, 0x39, 0x46, 0x8b, 0xd3, 0xaf, 0xe8, 0x63, 0x9d, 0xf9, 0x41, 0x40, 0x76, 0x40, 0x10, 0xa3}, //nolint:lll + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + encoding, hash, err := testCase.leaf.EncodeAndHash() + + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.encoding, encoding) + assert.Equal(t, testCase.hash, hash) + }) + } +} diff --git a/internal/trie/node/key_test.go b/internal/trie/node/key_test.go index 4ec4985527..c3413c1628 100644 --- a/internal/trie/node/key_test.go +++ b/internal/trie/node/key_test.go @@ -5,7 +5,7 @@ 
package node import ( "bytes" - "io" + "fmt" "testing" "github.com/golang/mock/gomock" @@ -13,6 +13,46 @@ import ( "github.com/stretchr/testify/require" ) +func Test_Branch_GetKey(t *testing.T) { + t.Parallel() + + branch := &Branch{ + Key: []byte{2}, + } + key := branch.GetKey() + assert.Equal(t, []byte{2}, key) +} + +func Test_Leaf_GetKey(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + Key: []byte{2}, + } + key := leaf.GetKey() + assert.Equal(t, []byte{2}, key) +} + +func Test_Branch_SetKey(t *testing.T) { + t.Parallel() + + branch := &Branch{ + Key: []byte{2}, + } + branch.SetKey([]byte{3}) + assert.Equal(t, &Branch{Key: []byte{3}}, branch) +} + +func Test_Leaf_SetKey(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + Key: []byte{2}, + } + leaf.SetKey([]byte{3}) + assert.Equal(t, &Leaf{Key: []byte{3}}, leaf) +} + func repeatBytes(n int, b byte) (slice []byte) { slice = make([]byte, n) for i := range slice { @@ -144,11 +184,60 @@ func Test_encodeKeyLength(t *testing.T) { }) } +//go:generate mockgen -destination=reader_mock_test.go -package $GOPACKAGE io Reader + +type readCall struct { + buffArgCap int + read []byte + n int // number of bytes read + err error +} + +func repeatReadCalls(rc readCall, length int) (readCalls []readCall) { + readCalls = make([]readCall, length) + for i := range readCalls { + readCalls[i] = readCall{ + buffArgCap: rc.buffArgCap, + n: rc.n, + err: rc.err, + } + if rc.read != nil { + readCalls[i].read = make([]byte, len(rc.read)) + copy(readCalls[i].read, rc.read) + } + } + return readCalls +} + +var _ gomock.Matcher = (*byteSliceCapMatcher)(nil) + +type byteSliceCapMatcher struct { + capacity int +} + +func (b *byteSliceCapMatcher) Matches(x interface{}) bool { + slice, ok := x.([]byte) + if !ok { + return false + } + return cap(slice) == b.capacity +} + +func (b *byteSliceCapMatcher) String() string { + return fmt.Sprintf("capacity of slice is not the expected capacity %d", b.capacity) +} + +func newByteSliceCapMatcher(capacity int) *byteSliceCapMatcher { + return &byteSliceCapMatcher{ + capacity: capacity, + } +} + func Test_decodeKey(t *testing.T) { t.Parallel() testCases := map[string]struct { - reader io.Reader + reads []readCall keyLength byte b []byte errWrapped error @@ -158,42 +247,54 @@ func Test_decodeKey(t *testing.T) { b: []byte{}, }, "short key length": { - reader: bytes.NewBuffer([]byte{1, 2, 3}), + reads: []readCall{ + {buffArgCap: 3, read: []byte{1, 2, 3}, n: 3}, + }, keyLength: 5, b: []byte{0x1, 0x0, 0x2, 0x0, 0x3}, }, "key read error": { - reader: bytes.NewBuffer(nil), + reads: []readCall{ + {buffArgCap: 3, err: errTest}, + }, keyLength: 5, errWrapped: ErrReadKeyData, - errMessage: "cannot read key data: EOF", + errMessage: "cannot read key data: test error", + }, + + "key read bytes count mismatch": { + reads: []readCall{ + {buffArgCap: 3, n: 2}, + }, + keyLength: 5, + errWrapped: ErrReadKeyData, + errMessage: "cannot read key data: read 2 bytes instead of 3", }, "long key length": { - reader: bytes.NewBuffer( - append( - []byte{ - 6, // key length - }, - repeatBytes(64, 7)..., // key data - )), + reads: []readCall{ + {buffArgCap: 1, read: []byte{6}, n: 1}, // key length + {buffArgCap: 35, read: repeatBytes(35, 7), n: 35}, // key data + }, keyLength: 0x3f, b: []byte{ - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 
0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, - 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, + 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7, 0x0, 0x7}, }, "key length read error": { - reader: bytes.NewBuffer(nil), + reads: []readCall{ + {buffArgCap: 1, err: errTest}, + }, keyLength: 0x3f, errWrapped: ErrReadKeyLength, - errMessage: "cannot read key length: EOF", + errMessage: "cannot read key length: test error", }, "key length too big": { - reader: bytes.NewBuffer(repeatBytes(257, 0xff)), + reads: repeatReadCalls(readCall{buffArgCap: 1, read: []byte{0xff}, n: 1}, 257), keyLength: 0x3f, errWrapped: ErrPartialKeyTooBig, errMessage: "partial key length cannot be larger than or equal to 2^16: 65598", @@ -204,8 +305,24 @@ func Test_decodeKey(t *testing.T) { testCase := testCase t.Run(name, func(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + + reader := NewMockReader(ctrl) + var previousCall *gomock.Call + for _, readCall := range testCase.reads { + byteSliceCapMatcher := newByteSliceCapMatcher(readCall.buffArgCap) + call := reader.EXPECT().Read(byteSliceCapMatcher). + DoAndReturn(func(b []byte) (n int, err error) { + copy(b, readCall.read) + return readCall.n, readCall.err + }) + if previousCall != nil { + call.After(previousCall) + } + previousCall = call + } - b, err := decodeKey(testCase.reader, testCase.keyLength) + b, err := decodeKey(reader, testCase.keyLength) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { diff --git a/internal/trie/node/leaf.go b/internal/trie/node/leaf.go index 77c884f397..0de16a3881 100644 --- a/internal/trie/node/leaf.go +++ b/internal/trie/node/leaf.go @@ -42,7 +42,7 @@ func NewLeaf(key, value []byte, dirty bool, generation uint64) *Leaf { func (l *Leaf) String() string { if len(l.Value) > 1024 { - return fmt.Sprintf("leaf key=%x value (hashed)=%x dirty=%v", l.Key, common.MustBlake2bHash(l.Value), l.dirty) + return fmt.Sprintf("leaf key=0x%x value (hashed)=0x%x dirty=%t", l.Key, common.MustBlake2bHash(l.Value), l.dirty) } - return fmt.Sprintf("leaf key=%x value=%v dirty=%v", l.Key, l.Value, l.dirty) + return fmt.Sprintf("leaf key=0x%x value=0x%x dirty=%t", l.Key, l.Value, l.dirty) } diff --git a/internal/trie/node/leaf_test.go b/internal/trie/node/leaf_test.go new file mode 100644 index 0000000000..d755eb724d --- /dev/null +++ b/internal/trie/node/leaf_test.go @@ -0,0 +1,77 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_NewLeaf(t *testing.T) { + t.Parallel() + + key := []byte{1, 2} + value := []byte{3, 4} + const dirty = true + const generation = 9 + + leaf := NewLeaf(key, value, dirty, generation) + + expectedLeaf := &Leaf{ + Key: key, + Value: value, + dirty: dirty, + generation: generation, + } + assert.Equal(t, expectedLeaf, leaf) + + // Check modifying passed slice modifies leaf slices + key[0] = 11 + value[0] = 13 + assert.Equal(t, expectedLeaf, leaf) +} + +func Test_Leaf_String(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + leaf *Leaf + s string + }{ + "empty leaf": { + leaf: &Leaf{}, + s: "leaf key=0x value=0x dirty=false", + }, + "leaf with value smaller than 1024": { + 
leaf: &Leaf{ + Key: []byte{1, 2}, + Value: []byte{3, 4}, + dirty: true, + }, + s: "leaf key=0x0102 value=0x0304 dirty=true", + }, + "leaf with value higher than 1024": { + leaf: &Leaf{ + Key: []byte{1, 2}, + Value: make([]byte, 1025), + dirty: true, + }, + s: "leaf key=0x0102 " + + "value (hashed)=0x307861663233363133353361303538646238383034626337353735323831663131663735313265326331346336373032393864306232336630396538386266333066 " + //nolint:lll + "dirty=true", + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + s := testCase.leaf.String() + + assert.Equal(t, testCase.s, s) + }) + } +} diff --git a/internal/trie/node/reader_mock_test.go b/internal/trie/node/reader_mock_test.go new file mode 100644 index 0000000000..2aa28d2998 --- /dev/null +++ b/internal/trie/node/reader_mock_test.go @@ -0,0 +1,49 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: io (interfaces: Reader) + +// Package node is a generated GoMock package. +package node + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockReader is a mock of Reader interface. +type MockReader struct { + ctrl *gomock.Controller + recorder *MockReaderMockRecorder +} + +// MockReaderMockRecorder is the mock recorder for MockReader. +type MockReaderMockRecorder struct { + mock *MockReader +} + +// NewMockReader creates a new mock instance. +func NewMockReader(ctrl *gomock.Controller) *MockReader { + mock := &MockReader{ctrl: ctrl} + mock.recorder = &MockReaderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockReader) EXPECT() *MockReaderMockRecorder { + return m.recorder +} + +// Read mocks base method. +func (m *MockReader) Read(arg0 []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read. +func (mr *MockReaderMockRecorder) Read(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockReader)(nil).Read), arg0) +} diff --git a/internal/trie/node/value_test.go b/internal/trie/node/value_test.go new file mode 100644 index 0000000000..f6fe989d1d --- /dev/null +++ b/internal/trie/node/value_test.go @@ -0,0 +1,30 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package node + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_Branch_GetValue(t *testing.T) { + t.Parallel() + + branch := &Branch{ + Value: []byte{2}, + } + value := branch.GetValue() + assert.Equal(t, []byte{2}, value) +} + +func Test_Leaf_GetValue(t *testing.T) { + t.Parallel() + + leaf := &Leaf{ + Value: []byte{2}, + } + value := leaf.GetValue() + assert.Equal(t, []byte{2}, value) +}
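Note (not part of the patch series): the leaf `EncodeAndHash` behaviour exercised by the tests above can be condensed into a short illustrative sketch. This is a minimal example, assuming it sits alongside the other `internal/trie/node` tests, of the two rules those tests assert: a non-dirty node with cached encoding and hash returns them as-is, and an encoding shorter than 32 bytes is returned as its own hash. The test name is hypothetical; the expected byte values are taken from the `Test_Leaf_EncodeAndHash` cases above.

// Copyright 2021 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package node

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Test_Leaf_EncodeAndHash_sketch is an illustrative sketch, not part of the
// patches above. It mirrors the "leaf not dirty with precomputed encoding
// and hash" and "small leaf encoding" cases of Test_Leaf_EncodeAndHash.
func Test_Leaf_EncodeAndHash_sketch(t *testing.T) {
	t.Parallel()

	leaf := &Leaf{
		Key:        []byte{1},
		Value:      []byte{2},
		dirty:      false,
		encoding:   []byte{3},
		hashDigest: []byte{4},
	}

	// Not dirty and already cached: the stored encoding and hash
	// are returned without re-encoding.
	encoding, hash, err := leaf.EncodeAndHash()
	require.NoError(t, err)
	assert.Equal(t, []byte{3}, encoding)
	assert.Equal(t, []byte{4}, hash)

	// Marking the leaf dirty forces a re-encode. The encoding is only
	// 4 bytes long, so it is returned as the hash instead of a
	// 32-byte blake2b digest.
	leaf.SetDirty(true)
	encoding, hash, err = leaf.EncodeAndHash()
	require.NoError(t, err)
	assert.Equal(t, []byte{0x41, 0x1, 0x4, 0x2}, encoding)
	assert.Equal(t, encoding, hash)
}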