diff --git a/codec.go b/codec.go index 83b42a0e..c68ec375 100644 --- a/codec.go +++ b/codec.go @@ -195,6 +195,10 @@ func MustUnmarshalRoot(data []byte) Ptr { return p } +var ( + errTooManySegments = errors.New("message has too many segments") +) + // An Encoder represents a framer for serializing a particular Cap'n // Proto stream. type Encoder struct { @@ -220,6 +224,9 @@ func (e *Encoder) Encode(m *Message) error { if nsegs == 0 { return errors.New("encode: message has no segments") } + if nsegs > 1<<32 { + return exc.WrapError("encode", errTooManySegments) + } e.bufs = append(e.bufs[:0], nil) // first element is placeholder for header maxSeg := SegmentID(nsegs - 1) hdrSize := streamHeaderSize(maxSeg) diff --git a/codec_test.go b/codec_test.go index 1e194123..9a66c9d8 100644 --- a/codec_test.go +++ b/codec_test.go @@ -2,8 +2,11 @@ package capnp import ( "bytes" + "errors" "io" "testing" + + "github.com/stretchr/testify/require" ) func TestEncoder(t *testing.T) { @@ -72,6 +75,40 @@ func TestDecoder(t *testing.T) { } } +type tooManySegsArena struct { + data []byte +} + +func (t *tooManySegsArena) NumSegments() int64 { return 1<<32 + 1 } + +func (t *tooManySegsArena) Data(id SegmentID) ([]byte, error) { + return nil, errors.New("no data") +} + +func (t *tooManySegsArena) Allocate(minsz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) { + return 0, nil, errors.New("cannot allocate") +} + +func (t *tooManySegsArena) Release() {} + +// TestEncoderTooManySegments verifies attempting to encode an arena that has +// more segments than possible. +func TestEncoderTooManySegments(t *testing.T) { + t.Parallel() + zeroWord := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + arena := &tooManySegsArena{data: zeroWord} + + // Setup via field because NewMessage checks arena has > 1 segments. + var msg Message + msg.Arena = arena + var buf bytes.Buffer + enc := NewEncoder(&buf) + err := enc.Encode(&msg) + + // Encoding should error with a specific error. + require.ErrorIs(t, err, errTooManySegments) +} + func TestDecoder_MaxMessageSize(t *testing.T) { t.Parallel() diff --git a/integration_test.go b/integration_test.go index f1530a01..38f375c6 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1770,7 +1770,7 @@ func BenchmarkMarshal_ReuseMsg(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { a := data[r.Intn(len(data))] - seg, err := msg.Reset(msg.Arena) + seg, err := msg.Reset(capnp.SingleSegment(nil)) if err != nil { b.Fatal(err) } diff --git a/message_test.go b/message_test.go index 5bd8448c..4d3289b3 100644 --- a/message_test.go +++ b/message_test.go @@ -646,11 +646,11 @@ var errReadOnlyArena = errors.New("Allocate called on read-only arena") func BenchmarkMessageGetFirstSegment(b *testing.B) { var msg Message - var arena Arena = SingleSegment(nil) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { + arena := SingleSegment(nil) _, err := msg.Reset(arena) if err != nil { b.Fatal(err) diff --git a/rpc/transport/transport.go b/rpc/transport/transport.go index bdb90844..31a36563 100644 --- a/rpc/transport/transport.go +++ b/rpc/transport/transport.go @@ -218,7 +218,10 @@ type outgoingMsg struct { } func (o *outgoingMsg) Release() { - if m := o.message.Message(); !o.released && m != nil { + if o.released { + return + } + if m := o.message.Message(); m != nil { o.released = true m.Release() } @@ -246,7 +249,10 @@ func (i *incomingMsg) Message() rpccp.Message { } func (i *incomingMsg) Release() { - if m := i.Message().Message(); !i.released && m != nil { + if i.released { + return + } + if m := i.Message().Message(); m != nil { i.released = true m.Release() } diff --git a/rpc/transport/transport_test.go b/rpc/transport/transport_test.go index 5cee8ef6..072ae052 100644 --- a/rpc/transport/transport_test.go +++ b/rpc/transport/transport_test.go @@ -45,12 +45,10 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error)) if err != nil { t.Fatal("t1.NewMessage #1:", err) } - defer callMsg.Release() bootMsg, err := t1.NewMessage() if err != nil { t.Fatal("t1.NewMessage #2:", err) } - defer bootMsg.Release() // Fill in bootstrap message boot, err := bootMsg.Message().NewBootstrap()