Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate zstd compression into chain exchange #842

Merged
merged 3 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions chainexchange/pubsub.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package chainexchange

import (
"bytes"
"context"
"fmt"
"sync"
"time"

"github.com/filecoin-project/go-f3/gpbft"
"github.com/filecoin-project/go-f3/internal/encoding"
"github.com/filecoin-project/go-f3/internal/psutil"
lru "github.com/hashicorp/golang-lru/v2"
logging "github.com/ipfs/go-log/v2"
Expand Down Expand Up @@ -38,18 +38,24 @@
pendingCacheAsWanted chan Message
topic *pubsub.Topic
stop func() error
encoding *encoding.ZSTD[*Message]
}

// NewPubSubChainExchange builds a PubSubChainExchange from the given options.
//
// It returns an error when newOptions rejects the options or when the zstd
// encoding for Message values cannot be constructed. The returned exchange
// starts with empty chainsWanted and chainsDiscovered maps, a
// pendingCacheAsWanted channel buffered to 100 messages, and the zstd encoding
// it will use for messages.
func NewPubSubChainExchange(o ...Option) (*PubSubChainExchange, error) {
opts, err := newOptions(o...)
if err != nil {
return nil, err
}
zstd, err := encoding.NewZSTD[*Message]()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we going with ZSTD by default?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For chain exchange yes. For GPBFT it's configurable via manifest.

Happy to make it configurable for chain exchange too if you think it's worth doing.

if err != nil {
return nil, err
}

Check warning on line 52 in chainexchange/pubsub.go

View check run for this annotation

Codecov / codecov/patch

chainexchange/pubsub.go#L51-L52

Added lines #L51 - L52 were not covered by tests
return &PubSubChainExchange{
options: opts,
chainsWanted: map[uint64]*lru.Cache[gpbft.ECChainKey, *chainPortion]{},
chainsDiscovered: map[uint64]*lru.Cache[gpbft.ECChainKey, *chainPortion]{},
pendingCacheAsWanted: make(chan Message, 100), // TODO: parameterise.
encoding: zstd,
}, nil
}

Expand Down Expand Up @@ -189,8 +195,7 @@

func (p *PubSubChainExchange) validatePubSubMessage(_ context.Context, _ peer.ID, msg *pubsub.Message) pubsub.ValidationResult {
var cmsg Message
buf := bytes.NewBuffer(msg.Data)
if err := cmsg.UnmarshalCBOR(buf); err != nil {
if err := p.encoding.Decode(msg.Data, &cmsg); err != nil {
log.Debugw("failed to decode message", "from", msg.GetFrom(), "err", err)
return pubsub.ValidationReject
}
Expand Down Expand Up @@ -266,12 +271,11 @@
log.Warnw("Dropping wanted cache entry. Chain exchange is too slow to process chains as wanted", "msg", msg)
}

// TODO: integrate zstd compression.
var buf bytes.Buffer
if err := msg.MarshalCBOR(&buf); err != nil {
encoded, err := p.encoding.Encode(&msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
if err := p.topic.Publish(ctx, buf.Bytes()); err != nil {
if err := p.topic.Publish(ctx, encoded); err != nil {
return fmt.Errorf("failed to publish message: %w", err)
}
return nil
Expand Down
50 changes: 27 additions & 23 deletions msg_encoding_test.go → encoding_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/filecoin-project/go-bitfield"
"github.com/filecoin-project/go-f3/gpbft"
"github.com/filecoin-project/go-f3/internal/encoding"
"github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
Expand All @@ -15,23 +16,23 @@ const seed = 1413

func BenchmarkCborEncoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder := &cborGMessageEncoding{}
encoder := encoding.NewCBOR[*PartialGMessage]()
msg := generateRandomPartialGMessage(b, rng)

b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := encoder.Encode(msg); err != nil {
require.NoError(b, err)
}
got, err := encoder.Encode(msg)
require.NoError(b, err)
require.NotEmpty(b, got)
}
})
}

func BenchmarkCborDecoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder := &cborGMessageEncoding{}
encoder := encoding.NewCBOR[*PartialGMessage]()
msg := generateRandomPartialGMessage(b, rng)
data, err := encoder.Encode(msg)
require.NoError(b, err)
Expand All @@ -40,34 +41,33 @@ func BenchmarkCborDecoding(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if got, err := encoder.Decode(data); err != nil {
require.NoError(b, err)
require.Equal(b, msg, got)
}
var got PartialGMessage
require.NoError(b, encoder.Decode(data, &got))
require.Equal(b, msg, &got)
}
})
}

func BenchmarkZstdEncoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder, err := newZstdGMessageEncoding()
encoder, err := encoding.NewZSTD[*PartialGMessage]()
require.NoError(b, err)
msg := generateRandomPartialGMessage(b, rng)

b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := encoder.Encode(msg); err != nil {
require.NoError(b, err)
}
got, err := encoder.Encode(msg)
require.NoError(b, err)
require.NotEmpty(b, got)
}
})
}

func BenchmarkZstdDecoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder, err := newZstdGMessageEncoding()
encoder, err := encoding.NewZSTD[*PartialGMessage]()
require.NoError(b, err)
msg := generateRandomPartialGMessage(b, rng)
data, err := encoder.Encode(msg)
Expand All @@ -77,10 +77,9 @@ func BenchmarkZstdDecoding(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if got, err := encoder.Decode(data); err != nil {
require.NoError(b, err)
require.Equal(b, msg, got)
}
var got PartialGMessage
require.NoError(b, encoder.Decode(data, &got))
require.Equal(b, msg, &got)
}
})
}
Expand All @@ -99,9 +98,8 @@ func generateRandomPartialGMessage(b *testing.B, rng *rand.Rand) *PartialGMessag
func generateRandomGMessage(b *testing.B, rng *rand.Rand) *gpbft.GMessage {
var maybeTicket []byte
if rng.Float64() < 0.5 {
generateRandomBytes(b, rng, 96)
maybeTicket = generateRandomBytes(b, rng, 96)
}

return &gpbft.GMessage{
Sender: gpbft.ActorID(rng.Uint64()),
Vote: generateRandomPayload(b, rng),
Expand All @@ -114,7 +112,7 @@ func generateRandomGMessage(b *testing.B, rng *rand.Rand) *gpbft.GMessage {
func generateRandomJustification(b *testing.B, rng *rand.Rand) *gpbft.Justification {
return &gpbft.Justification{
Vote: generateRandomPayload(b, rng),
Signers: generateRandomBitfield(rng),
Signers: generateRandomBitfield(b, rng),
Signature: generateRandomBytes(b, rng, 96),
}
}
Expand All @@ -138,12 +136,18 @@ func generateRandomPayload(b *testing.B, rng *rand.Rand) gpbft.Payload {
}
}

func generateRandomBitfield(rng *rand.Rand) bitfield.BitField {
func generateRandomBitfield(b *testing.B, rng *rand.Rand) bitfield.BitField {
ids := make([]uint64, rng.Intn(2_000)+1)
for i := range ids {
ids[i] = rng.Uint64()
}
return bitfield.NewFromSet(ids)
// Copy the bitfield once to force initialization of internal bit field state.
// This is to work around the equality assertions in tests, where under the hood
// reflection is used to check for equality. This way we can avoid writing custom
// equality checking for bitfields.
bitField, err := bitfield.NewFromSet(ids).Copy()
require.NoError(b, err)
return bitField
}

func generateRandomECChain(b *testing.B, rng *rand.Rand, length int) *gpbft.ECChain {
Expand Down
15 changes: 8 additions & 7 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/filecoin-project/go-f3/gpbft"
"github.com/filecoin-project/go-f3/internal/caching"
"github.com/filecoin-project/go-f3/internal/clock"
"github.com/filecoin-project/go-f3/internal/encoding"
"github.com/filecoin-project/go-f3/internal/psutil"
"github.com/filecoin-project/go-f3/internal/writeaheadlog"
"github.com/filecoin-project/go-f3/manifest"
Expand Down Expand Up @@ -52,7 +53,7 @@ type gpbftRunner struct {
selfMessages map[uint64]map[roundPhase][]*gpbft.GMessage

inputs gpbftInputs
msgEncoding gMessageEncoding
msgEncoding encoding.EncodeDecoder[*PartialGMessage]
pmm *partialMessageManager
pmv *cachingPartialValidator
pmCache *caching.GroupedSet
Expand Down Expand Up @@ -138,12 +139,12 @@ func newRunner(
runner.participant = p

if runner.manifest.PubSub.CompressionEnabled {
runner.msgEncoding, err = newZstdGMessageEncoding()
runner.msgEncoding, err = encoding.NewZSTD[*PartialGMessage]()
if err != nil {
return nil, err
}
} else {
runner.msgEncoding = &cborGMessageEncoding{}
runner.msgEncoding = encoding.NewCBOR[*PartialGMessage]()
}

runner.pmm, err = newPartialMessageManager(runner.Progress, ps, m)
Expand Down Expand Up @@ -541,15 +542,15 @@ func (h *gpbftRunner) validatePubsubMessage(ctx context.Context, _ peer.ID, msg
recordValidationTime(ctx, start, _result)
}(time.Now())

pgmsg, err := h.msgEncoding.Decode(msg.Data)
if err != nil {
var pgmsg PartialGMessage
if err := h.msgEncoding.Decode(msg.Data, &pgmsg); err != nil {
log.Debugw("failed to decode message", "from", msg.GetFrom(), "err", err)
return pubsub.ValidationReject
}

gmsg, completed := h.pmm.CompleteMessage(ctx, pgmsg)
gmsg, completed := h.pmm.CompleteMessage(ctx, &pgmsg)
if !completed {
partiallyValidatedMessage, err := h.pmv.PartiallyValidateMessage(pgmsg)
partiallyValidatedMessage, err := h.pmv.PartiallyValidateMessage(&pgmsg)
result := pubsubValidationResultFromError(err)
if result == pubsub.ValidationAccept {
msg.ValidatorData = partiallyValidatedMessage
Expand Down
86 changes: 86 additions & 0 deletions internal/encoding/encoding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package encoding

import (
"bytes"
"fmt"

"github.com/klauspost/compress/zstd"
cbg "github.com/whyrusleeping/cbor-gen"
)

// maxDecompressedSize is the default maximum amount of memory allocated by the
// zstd decoder. The limit of 1MiB is chosen based on the default maximum message
// size in GossipSub.
//
// ZSTD.Encode applies the same limit to the uncompressed CBOR payload, so that
// a value which could never be decompressed is rejected before it is compressed.
const maxDecompressedSize = 1 << 20

// CBORMarshalUnmarshaler is the type constraint for values handled by this
// package: a type must implement both the cbor-gen marshaler and unmarshaler
// interfaces.
type CBORMarshalUnmarshaler interface {
cbg.CBORMarshaler
cbg.CBORUnmarshaler
}

// EncodeDecoder converts values of type T to and from a byte representation.
type EncodeDecoder[T CBORMarshalUnmarshaler] interface {
// Encode returns the byte representation of v.
Encode(v T) ([]byte, error)
// Decode populates the given value from the given bytes.
Decode([]byte, T) error
}

// CBOR is a stateless encoding that represents values as plain CBOR by
// delegating to the value's own MarshalCBOR and UnmarshalCBOR methods.
type CBOR[T CBORMarshalUnmarshaler] struct{}

// NewCBOR returns a CBOR encoding for values of type T.
func NewCBOR[T CBORMarshalUnmarshaler]() *CBOR[T] {
	return new(CBOR[T])
}

// Encode marshals m as CBOR into an in-memory buffer and returns the buffered
// bytes. If MarshalCBOR fails, it returns a nil slice together with that error.
func (c *CBOR[T]) Encode(m T) ([]byte, error) {
var buf bytes.Buffer
if err := m.MarshalCBOR(&buf); err != nil {
return nil, err
}

Check warning on line 36 in internal/encoding/encoding.go

View check run for this annotation

Codecov / codecov/patch

internal/encoding/encoding.go#L35-L36

Added lines #L35 - L36 were not covered by tests
return buf.Bytes(), nil
}

// Decode unmarshals the CBOR bytes in v into t and returns whatever error
// t.UnmarshalCBOR reports.
func (c *CBOR[T]) Decode(v []byte, t T) error {
	return t.UnmarshalCBOR(bytes.NewReader(v))
}

// ZSTD is an encoding that first CBOR-encodes a value and then compresses the
// result with zstd. Decoding applies the two steps in reverse.
type ZSTD[T CBORMarshalUnmarshaler] struct {
cborEncoding *CBOR[T] // converts values to and from uncompressed CBOR
compressor *zstd.Encoder // used by Encode through EncodeAll
decompressor *zstd.Decoder // used by Decode through DecodeAll
}

// NewZSTD constructs a ZSTD encoding for values of type T.
//
// The zstd encoder and decoder are both created with a nil stream; Encode and
// Decode in this file only call their EncodeAll and DecodeAll methods. The
// decoder is configured with WithDecoderMaxMemory(maxDecompressedSize).
//
// It returns an error if either the zstd encoder or decoder cannot be created.
func NewZSTD[T CBORMarshalUnmarshaler]() (*ZSTD[T], error) {
writer, err := zstd.NewWriter(nil)
if err != nil {
return nil, err
}

Check warning on line 55 in internal/encoding/encoding.go

View check run for this annotation

Codecov / codecov/patch

internal/encoding/encoding.go#L54-L55

Added lines #L54 - L55 were not covered by tests
reader, err := zstd.NewReader(nil, zstd.WithDecoderMaxMemory(maxDecompressedSize))
if err != nil {
return nil, err
}

Check warning on line 59 in internal/encoding/encoding.go

View check run for this annotation

Codecov / codecov/patch

internal/encoding/encoding.go#L58-L59

Added lines #L58 - L59 were not covered by tests
return &ZSTD[T]{
cborEncoding: &CBOR[T]{},
compressor: writer,
decompressor: reader,
}, nil
}

// Encode CBOR-encodes m and returns the zstd-compressed form of those bytes.
//
// It returns an error if CBOR encoding fails, or if the uncompressed CBOR
// payload is larger than maxDecompressedSize. The destination buffer handed to
// EncodeAll is pre-sized to the length of the uncompressed payload.
func (c *ZSTD[T]) Encode(m T) ([]byte, error) {
cborEncoded, err := c.cborEncoding.Encode(m)
// NOTE(review): the size check below runs before err is inspected. This is
// harmless because CBOR.Encode returns a nil slice on failure, so the length
// is zero and control falls through to the err check. Checking err first
// would be the conventional order.
if len(cborEncoded) > maxDecompressedSize {
// Error out early if the encoded value is too large to be decompressed.
return nil, fmt.Errorf("encoded value cannot exceed maximum size: %d > %d", len(cborEncoded), maxDecompressedSize)
}

Check warning on line 72 in internal/encoding/encoding.go

View check run for this annotation

Codecov / codecov/patch

internal/encoding/encoding.go#L70-L72

Added lines #L70 - L72 were not covered by tests
if err != nil {
return nil, err
}

Check warning on line 75 in internal/encoding/encoding.go

View check run for this annotation

Codecov / codecov/patch

internal/encoding/encoding.go#L74-L75

Added lines #L74 - L75 were not covered by tests
compressed := c.compressor.EncodeAll(cborEncoded, make([]byte, 0, len(cborEncoded)))
return compressed, nil
}

// Decode decompresses v with zstd and then CBOR-decodes the resulting bytes
// into t.
//
// The destination buffer handed to DecodeAll starts with capacity len(v). It
// returns the decompression error if there is one, otherwise the result of
// CBOR decoding.
func (c *ZSTD[T]) Decode(v []byte, t T) error {
cborEncoded, err := c.decompressor.DecodeAll(v, make([]byte, 0, len(v)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Future change: we should use a buffer pool for these short-lived buffers (https://pkg.go.dev/sync#Pool). If we do that, we can also allocate these buffers with 1MiB capacities and use WithDecodeAllCapLimit to avoid any allocations.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the plan. Thank you for reminding me. Captured #849

if err != nil {
return err
}
return c.cborEncoding.Decode(cborEncoded, t)
}
Loading
Loading