Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(lib/trie): Parallel hash trie. #1657

Merged
merged 10 commits into from
Jul 2, 2021
Merged
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ require (
github.com/urfave/cli v1.20.0
github.com/wasmerio/go-ext-wasm v0.3.2-0.20200326095750-0a32be6068ec
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40 // indirect
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1
google.golang.org/appengine v1.6.5 // indirect
Expand Down
4 changes: 3 additions & 1 deletion lib/trie/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ func encodeRecursive(n node, enc []byte) ([]byte, error) {
return []byte{}, nil
}

nenc, err := n.encode()
hasher := NewHasher(false)
defer returnHasherToPool(hasher)
nenc, err := hasher.encode(n)
if err != nil {
return enc, err
}
Expand Down
185 changes: 175 additions & 10 deletions lib/trie/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,64 @@
package trie

import (
"bytes"
"context"
"hash"
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"golang.org/x/sync/errgroup"

"golang.org/x/crypto/blake2b"
)

type sliceBuffer []byte

func (b *sliceBuffer) Write(data []byte) (n int, err error) {
*b = append(*b, data...)
return len(data), nil
}

func (b *sliceBuffer) Reset() {
*b = (*b)[:0]
}

noot marked this conversation as resolved.
Show resolved Hide resolved
// Hasher is a wrapper around a hash function
type Hasher struct {
hash hash.Hash
hash hash.Hash
tmp sliceBuffer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering why not use bytes.Buffer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because we are calculating the upper limit of the slice and preallocating it. Thus, avoiding dynamic allocation.

parallel bool // Whether to use parallel threads when hashing
}

// HasherPool creates a pool of Hasher.
var HasherPool = sync.Pool{
noot marked this conversation as resolved.
Show resolved Hide resolved
New: func() interface{} {
h, _ := blake2b.New256(nil)

return &Hasher{
tmp: make(sliceBuffer, 0, 520), // cap is as large as a full branch node.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how was this calculated? branch nodes can still store a value, so the cap might be larger than this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed it will be only keyed. I will look into this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me know if you fix this and I can take another look

hash: h,
}
},
}

// NewHasher create new Hasher instance
func NewHasher() (*Hasher, error) {
h, err := blake2b.New256(nil)
if err != nil {
return nil, err
}
func NewHasher(parallel bool) *Hasher {
h := HasherPool.Get().(*Hasher)
h.parallel = parallel
return h
}

return &Hasher{
hash: h,
}, nil
func returnHasherToPool(h *Hasher) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this a receiver function on Hasher? It would be nice to just be able to defer h.Return().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

h.tmp.Reset()
h.hash.Reset()
HasherPool.Put(h)
}

// Hash encodes the node and then hashes it if its encoded length is > 32 bytes
func (h *Hasher) Hash(n node) (res []byte, err error) {
encNode, err := n.encode()
encNode, err := h.encode(n)
if err != nil {
return nil, err
}
Expand All @@ -51,6 +84,7 @@ func (h *Hasher) Hash(n node) (res []byte, err error) {
return encNode, nil
}

h.hash.Reset()
// otherwise, hash encoded node
_, err = h.hash.Write(encNode)
if err == nil {
Expand All @@ -59,3 +93,134 @@ func (h *Hasher) Hash(n node) (res []byte, err error) {

return res, err
}

// encode is the high-level function wrapping the encoding for different node types
// encoding has the following format:
// NodeHeader | Extra partial key length | Partial Key | Value
func (h *Hasher) encode(n node) ([]byte, error) {
switch n := n.(type) {
case *branch:
return h.encodeBranch(n)
case *leaf:
return h.encodeLeaf(n)
case nil:
return []byte{0}, nil
}

return nil, nil
}

// encodeBranch encodes a branch with the encoding specified at the top of this package
func (h *Hasher) encodeBranch(b *branch) ([]byte, error) {
if !b.dirty && b.encoding != nil {
return b.encoding, nil
}
h.tmp.Reset()

encoding, err := b.header()
h.tmp = append(h.tmp, encoding...)
if err != nil {
return nil, err
}

h.tmp = append(h.tmp, nibblesToKeyLE(b.key)...)
h.tmp = append(h.tmp, common.Uint16ToBytes(b.childrenBitmap())...)

if b.value != nil {
buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}
_, err = se.Encode(b.value)
if err != nil {
return h.tmp, err
}
h.tmp = append(h.tmp, buffer.Bytes()...)
}

if h.parallel {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if h.parallel {
switch h.parallel { {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switch for one case seems overkill.

wg, _ := errgroup.WithContext(context.Background())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL errgroup

resBuff := make([][]byte, 16)
for i := 0; i < 16; i++ {
func(i int) {
wg.Go(func() error {
child := b.children[i]
if child == nil {
return nil
}

hasher := NewHasher(false)
defer returnHasherToPool(hasher)

encChild, err := hasher.Hash(child)
if err != nil {
return err
}

scEncChild, err := scale.Encode(encChild)
if err != nil {
return err
}
noot marked this conversation as resolved.
Show resolved Hide resolved
resBuff[i] = scEncChild
return nil
})
}(i)
}
if err := wg.Wait(); err != nil {
return nil, err
}

for _, v := range resBuff {
if v != nil {
h.tmp = append(h.tmp, v...)
}
}
} else {
for i := 0; i < 16; i++ {
if child := b.children[i]; child != nil {
hasher := NewHasher(false)
defer returnHasherToPool(hasher)

encChild, err := hasher.Hash(child)
if err != nil {
return nil, err
}

scEncChild, err := scale.Encode(encChild)
if err != nil {
return nil, err
}
h.tmp = append(h.tmp, scEncChild...)
}
}
}

return h.tmp, nil
}

// encodeLeaf encodes a leaf with the encoding specified at the top of this package
func (h *Hasher) encodeLeaf(l *leaf) ([]byte, error) {
if !l.dirty && l.encoding != nil {
return l.encoding, nil
}

h.tmp.Reset()

encoding, err := l.header()
h.tmp = append(h.tmp, encoding...)
if err != nil {
return nil, err
}

h.tmp = append(h.tmp, nibblesToKeyLE(l.key)...)

buffer := bytes.Buffer{}
se := scale.Encoder{Writer: &buffer}

_, err = se.Encode(l.value)
if err != nil {
return nil, err
}

h.tmp = append(h.tmp, buffer.Bytes()...)
l.encoding = h.tmp
return h.tmp, nil
}
30 changes: 10 additions & 20 deletions lib/trie/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,10 @@ func generateRand(size int) [][]byte {
}

func TestNewHasher(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatalf("error creating new hasher: %s", err)
} else if hasher == nil {
t.Fatal("did not create new hasher")
}
hasher := NewHasher(false)
defer returnHasherToPool(hasher)

_, err = hasher.hash.Write([]byte("noot"))
_, err := hasher.hash.Write([]byte("noot"))
if err != nil {
t.Error(err)
}
Expand All @@ -62,10 +58,8 @@ func TestNewHasher(t *testing.T) {
}

func TestHashLeaf(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer returnHasherToPool(hasher)

n := &leaf{key: generateRandBytes(380), value: generateRandBytes(64)}
h, err := hasher.Hash(n)
Expand All @@ -77,10 +71,8 @@ func TestHashLeaf(t *testing.T) {
}

func TestHashBranch(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer returnHasherToPool(hasher)

n := &branch{key: generateRandBytes(380), value: generateRandBytes(380)}
n.children[3] = &leaf{key: generateRandBytes(380), value: generateRandBytes(380)}
Expand All @@ -93,13 +85,11 @@ func TestHashBranch(t *testing.T) {
}

func TestHashShort(t *testing.T) {
hasher, err := NewHasher()
if err != nil {
t.Fatal(err)
}
hasher := NewHasher(false)
defer returnHasherToPool(hasher)

n := &leaf{key: generateRandBytes(2), value: generateRandBytes(3)}
expected, err := n.encode()
expected, err := hasher.encode(n)
if err != nil {
t.Fatal(err)
}
Expand Down
Loading