feat(lib/trie): Parallel hash trie. #1657

Merged · 10 commits · Jul 2, 2021
Changes from 6 commits
1 change: 0 additions & 1 deletion dot/state/service_test.go
@@ -232,7 +232,6 @@ func TestService_PruneStorage(t *testing.T) {
}

var toFinalize common.Hash

for i := 0; i < 3; i++ {
block, trieState := generateBlockWithRandomTrie(t, serv, nil, int64(i+1))
block.Header.Digest = types.Digest{
1 change: 1 addition & 0 deletions go.mod
@@ -52,6 +52,7 @@ require (
github.com/urfave/cli v1.20.0
github.com/wasmerio/go-ext-wasm v0.3.2-0.20200326095750-0a32be6068ec
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40 // indirect
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
google.golang.org/appengine v1.6.5 // indirect
1 change: 0 additions & 1 deletion go.sum
@@ -156,7 +156,6 @@ github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200j
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.4.0/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw=
github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
4 changes: 3 additions & 1 deletion lib/trie/encode.go
@@ -34,7 +34,9 @@ func encodeRecursive(n node, enc []byte) ([]byte, error) {
return []byte{}, nil
}

nenc, err := n.encode()
hasher := NewHasher(false)
defer hasher.returnToPool()
nenc, err := hasher.encode(n)
if err != nil {
return enc, err
}
183 changes: 173 additions & 10 deletions lib/trie/hash.go
@@ -17,31 +17,62 @@
package trie

import (
"bytes"
"context"
"hash"
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"golang.org/x/crypto/blake2b"
"golang.org/x/sync/errgroup"
)

type sliceBuffer []byte

func (b *sliceBuffer) write(data []byte) {
*b = append(*b, data...)
}

func (b *sliceBuffer) reset() {
*b = (*b)[:0]
}

// Hasher is a wrapper around a hash function
type Hasher struct {
hash hash.Hash
tmp sliceBuffer
Member: wondering why not use bytes.Buffer?

Contributor Author: Because we are calculating the upper limit of the slice and preallocating it, thus avoiding dynamic allocation.

parallel bool // Whether to use parallel threads when hashing
}
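For context on the bytes.Buffer question above, a minimal standalone sketch (not part of this diff; the scratch type and the sizes are illustrative) of the preallocation argument: a slice whose capacity is reserved up front never reallocates during these writes, while bytes.Buffer grows its backing array on demand.

```go
package main

import (
	"bytes"
	"fmt"
)

// scratch mirrors the sliceBuffer idea: a byte slice whose capacity is
// reserved once, so repeated writes only advance the length.
type scratch []byte

func (s *scratch) write(data []byte) { *s = append(*s, data...) }
func (s *scratch) reset()            { *s = (*s)[:0] }

func main() {
	buf := make(scratch, 0, 520) // reserve the upper bound once
	before := cap(buf)
	for i := 0; i < 10; i++ {
		buf.write(make([]byte, 32)) // e.g. one 32-byte child hash per write
	}
	fmt.Println("no reallocation:", cap(buf) == before) // true: the writes fit the reserved cap

	var bb bytes.Buffer // works too, but may reallocate as it grows
	bb.Write(make([]byte, 320))
	fmt.Println("bytes.Buffer len:", bb.Len())

	buf.reset() // length back to zero, capacity kept for reuse
}
```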

// hasherPool creates a pool of Hasher.
var hasherPool = sync.Pool{
New: func() interface{} {
h, _ := blake2b.New256(nil)

return &Hasher{
tmp: make(sliceBuffer, 0, 520), // cap is as large as a full branch node.
Contributor: how was this calculated? Branch nodes can still store a value, so the cap might be larger than this.

Contributor Author: I assumed it would only be keyed. I will look into this.

Contributor: let me know if you fix this and I can take another look.

hash: h,
}
},
}

// NewHasher create new Hasher instance
func NewHasher() (*Hasher, error) {
h, err := blake2b.New256(nil)
if err != nil {
return nil, err
}
func NewHasher(parallel bool) *Hasher {
h := hasherPool.Get().(*Hasher)
h.parallel = parallel
return h
}

return &Hasher{
hash: h,
}, nil
func (h *Hasher) returnToPool() {
h.tmp.reset()
h.hash.Reset()
hasherPool.Put(h)
}

// Hash encodes the node and then hashes it if its encoded length is > 32 bytes
func (h *Hasher) Hash(n node) (res []byte, err error) {
encNode, err := n.encode()
encNode, err := h.encode(n)
if err != nil {
return nil, err
}
@@ -51,6 +82,7 @@ func (h *Hasher) Hash(n node) (res []byte, err error) {
return encNode, nil
}

h.hash.Reset()
// otherwise, hash encoded node
_, err = h.hash.Write(encNode)
if err == nil {
@@ -59,3 +91,134 @@ func (h *Hasher) Hash(n node) (res []byte, err error) {

return res, err
}
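The constructor change above replaces error handling with a pool round-trip. A standalone sketch of that pool-and-reset lifecycle (the scratchHasher type, the hashOnce helper, and the 520-byte cap here are illustrative, not part of the package):

```go
package main

import (
	"fmt"
	"hash"
	"sync"

	"golang.org/x/crypto/blake2b"
)

// scratchHasher mimics the pooled Hasher above: it is reused across calls via
// sync.Pool, and its scratch buffer and hash state are reset before reuse.
type scratchHasher struct {
	hash hash.Hash
	tmp  []byte
}

var pool = sync.Pool{
	New: func() interface{} {
		h, _ := blake2b.New256(nil)
		return &scratchHasher{hash: h, tmp: make([]byte, 0, 520)}
	},
}

func hashOnce(data []byte) []byte {
	h := pool.Get().(*scratchHasher)
	defer func() {
		h.tmp = h.tmp[:0] // reset state before Put, as returnToPool does
		h.hash.Reset()
		pool.Put(h)
	}()

	h.tmp = append(h.tmp, data...)
	h.hash.Write(h.tmp) // hash.Hash.Write never returns an error
	return h.hash.Sum(nil)
}

func main() {
	fmt.Printf("%x\n", hashOnce([]byte("noot")))
}
```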

// encode is the high-level function wrapping the encoding for different node types
// encoding has the following format:
// NodeHeader | Extra partial key length | Partial Key | Value
func (h *Hasher) encode(n node) ([]byte, error) {
switch n := n.(type) {
case *branch:
return h.encodeBranch(n)
case *leaf:
return h.encodeLeaf(n)
case nil:
return []byte{0}, nil
}

return nil, nil
}

func encodeAndHash(n node) ([]byte, error) {
h := NewHasher(false)
defer h.returnToPool()

encChild, err := h.Hash(n)
if err != nil {
return nil, err
}

scEncChild, err := scale.Encode(encChild)
if err != nil {
return nil, err
}
return scEncChild, nil
}

// encodeBranch encodes a branch with the encoding specified at the top of this package
func (h *Hasher) encodeBranch(b *branch) ([]byte, error) {
if !b.dirty && b.encoding != nil {
return b.encoding, nil
}
h.tmp.reset()

encoding, err := b.header()
h.tmp.write(encoding)
if err != nil {
return nil, err
}

h.tmp.write(nibblesToKeyLE(b.key))
h.tmp.write(common.Uint16ToBytes(b.childrenBitmap()))

if b.value != nil {
buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}
_, err = se.Encode(b.value)
if err != nil {
return h.tmp, err
}
h.tmp.write(buffer.Bytes())
}

if h.parallel {
Contributor: nit:

Suggested change:
-	if h.parallel {
+	switch h.parallel {

Contributor Author: switch for one case seems overkill.
wg, _ := errgroup.WithContext(context.Background())
Contributor: TIL errgroup

resBuff := make([][]byte, 16)
for i := 0; i < 16; i++ {
func(i int) {
wg.Go(func() error {
child := b.children[i]
if child == nil {
return nil
}

var err error
resBuff[i], err = encodeAndHash(child)
if err != nil {
return err
}
return nil
})
}(i)
}
if err := wg.Wait(); err != nil {
return nil, err
}

for _, v := range resBuff {
if v != nil {
h.tmp.write(v)
}
}
} else {
for i := 0; i < 16; i++ {
if child := b.children[i]; child != nil {
scEncChild, err := encodeAndHash(child)
if err != nil {
return nil, err
}
h.tmp.write(scEncChild)
}
}
}

return h.tmp, nil
}
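For readers who, like the reviewer above, are new to errgroup: a standalone sketch of the fan-out pattern encodeBranch uses, with a trivial per-index computation standing in for encodeAndHash (the results slice and the worker body are illustrative).

```go
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	// One goroutine per child slot; results are gathered in order and the
	// first non-nil error cancels the group.
	wg, _ := errgroup.WithContext(context.Background())

	results := make([]int, 16)
	for i := 0; i < 16; i++ {
		i := i // capture the loop variable for the closure
		wg.Go(func() error {
			results[i] = i * i // stand-in for encodeAndHash(child)
			return nil
		})
	}

	// Wait blocks until every goroutine returns and reports the first error,
	// which is how encodeBranch aborts when encoding a child fails.
	if err := wg.Wait(); err != nil {
		fmt.Println("encode failed:", err)
		return
	}
	fmt.Println(results[:4])
}
```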

// encodeLeaf encodes a leaf with the encoding specified at the top of this package
func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) {
if !l.dirty && l.encoding != nil {
return l.encoding, nil
}

h.tmp.reset()

encoding, err := l.header()
h.tmp.write(encoding)
if err != nil {
return nil, err
}

h.tmp.write(nibblesToKeyLE(l.key))

buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}

_, err = se.Encode(l.value)
if err != nil {
return nil, err
}

h.tmp.write(buffer.Bytes())
l.encoding = h.tmp
return h.tmp, nil
}
30 changes: 10 additions & 20 deletions lib/trie/hash_test.go
@@ -41,14 +41,10 @@ func generateRand(size int) [][]byte {
}

func TestNewHasher(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatalf("error creating new hasher: %s", err)
} else if hasher == nil {
t.Fatal("did not create new hasher")
}
hasher := NewHasher(false)
defer hasher.returnToPool()

_, err = hasher.hash.Write([]byte("noot"))
_, err := hasher.hash.Write([]byte("noot"))
if err != nil {
t.Error(err)
}
@@ -62,10 +58,8 @@ func TestNewHasher(t *testing.T) {
}

func TestHashLeaf(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer hasher.returnToPool()

n := &leaf{key: generateRandBytes(380), value: generateRandBytes(64)}
h, err := hasher.Hash(n)
@@ -77,10 +71,8 @@ func TestHashLeaf(t *testing.T) {
}

func TestHashBranch(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer hasher.returnToPool()

n := &branch{key: generateRandBytes(380), value: generateRandBytes(380)}
n.children[3] = &leaf{key: generateRandBytes(380), value: generateRandBytes(380)}
@@ -93,13 +85,11 @@ func TestHashBranch(t *testing.T) {
}

func TestHashShort(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer hasher.returnToPool()

n := &leaf{key: generateRandBytes(2), value: generateRandBytes(3)}
expected, err := n.encode()
expected, err := hasher.encode(n)
if err != nil {
t.Fatal(err)
}