diff --git a/integration_test.go b/integration_test.go index e28026e6..02fa2e68 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1955,6 +1955,50 @@ func BenchmarkDecode(b *testing.B) { } } +func BenchmarkDecode_Reuse(b *testing.B) { + var buf bytes.Buffer + + r := rand.New(rand.NewSource(12345)) + enc := capnp.NewEncoder(&buf) + count := 10000 + + for i := 0; i < count; i++ { + a := generateA(r) + msg, seg, _ := capnp.NewMessage(capnp.SingleSegment(nil)) + root, _ := air.NewRootBenchmarkA(seg) + a.fill(root) + enc.Encode(msg) + } + + blob := buf.Bytes() + + b.ReportAllocs() + b.SetBytes(int64(buf.Len())) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + dec := capnp.NewDecoder(bytes.NewReader(blob)) + dec.ReuseBuffer() + + for { + msg, err := dec.Decode() + + if err == io.EOF { + break + } + + if err != nil { + b.Fatal(err) + } + + _, err = air.ReadRootBenchmarkA(msg) + if err != nil { + b.Fatal(err) + } + } + } +} + type testArena []byte func (ta testArena) NumSegments() int64 { diff --git a/mem.go b/mem.go index 6898b8d3..919b5027 100644 --- a/mem.go +++ b/mem.go @@ -359,6 +359,23 @@ func (ssa *singleSegmentArena) Allocate(sz Size, segs map[SegmentID]*Segment) (S return 0, *ssa, nil } +type roSingleSegment []byte + +func (ss roSingleSegment) NumSegments() int64 { + return 1 +} + +func (ss roSingleSegment) Data(id SegmentID) ([]byte, error) { + if id != 0 { + return nil, errSegmentOutOfBounds + } + return ss, nil +} + +func (ss roSingleSegment) Allocate(sz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) { + return 0, nil, errors.New("capnp: segment is read-only") +} + type multiSegmentArena [][]byte // MultiSegment returns a new arena that allocates new segments when @@ -421,6 +438,14 @@ func (msa *multiSegmentArena) Allocate(sz Size, segs map[SegmentID]*Segment) (Se type Decoder struct { r io.Reader + segbuf [msgHeaderSize]byte + hdrbuf []byte + + reuse bool + buf []byte + msg Message + arena roSingleSegment + // Maximum number of bytes that can be read per call to Decode. // If not set, a reasonable default is used. MaxMessageSize uint64 @@ -443,24 +468,23 @@ func (d *Decoder) Decode() (*Message, error) { if maxSize == 0 { maxSize = defaultDecodeLimit } - var maxSegBuf [msgHeaderSize]byte - if _, err := io.ReadFull(d.r, maxSegBuf[:]); err != nil { + if _, err := io.ReadFull(d.r, d.segbuf[:]); err != nil { return nil, err } - maxSeg := binary.LittleEndian.Uint32(maxSegBuf[:]) + maxSeg := binary.LittleEndian.Uint32(d.segbuf[:]) if maxSeg > maxStreamSegments { return nil, errTooManySegments } hdrSize := streamHeaderSize(maxSeg) - if hdrSize > maxSize { + if hdrSize > maxSize || hdrSize > (1<<31-1) { return nil, errDecodeLimit } - hdrBuf := make([]byte, hdrSize) - copy(hdrBuf, maxSegBuf[:]) - if _, err := io.ReadFull(d.r, hdrBuf[msgHeaderSize:]); err != nil { + d.hdrbuf = resizeSlice(d.hdrbuf, int(hdrSize)) + copy(d.hdrbuf, d.segbuf[:]) + if _, err := io.ReadFull(d.r, d.hdrbuf[msgHeaderSize:]); err != nil { return nil, err } - hdr, _, err := parseStreamHeader(hdrBuf) + hdr, _, err := parseStreamHeader(d.hdrbuf) if err != nil { return nil, err } @@ -468,18 +492,52 @@ func (d *Decoder) Decode() (*Message, error) { if err != nil { return nil, err } - if total > maxSize-hdrSize { + // TODO(someday): if total size is greater than can fit in one buffer, + // attempt to allocate buffer per segment. + if total > maxSize-hdrSize || total > (1<<31-1) { return nil, errDecodeLimit } - buf := make([]byte, int(total)) - if _, err := io.ReadFull(d.r, buf); err != nil { - return nil, err + if !d.reuse { + buf := make([]byte, int(total)) + if _, err := io.ReadFull(d.r, buf); err != nil { + return nil, err + } + arena, err := demuxArena(hdr, buf) + if err != nil { + return nil, err + } + return &Message{Arena: arena}, nil } - arena, err := demuxArena(hdr, buf) - if err != nil { + d.buf = resizeSlice(d.buf, int(total)) + if _, err := io.ReadFull(d.r, d.buf); err != nil { return nil, err } - return &Message{Arena: arena}, nil + var arena Arena + if hdr.maxSegment() == 0 { + d.arena = d.buf[:len(d.buf):len(d.buf)] + arena = &d.arena + } else { + var err error + arena, err = demuxArena(hdr, d.buf) + if err != nil { + return nil, err + } + } + d.msg.Reset(arena) + return &d.msg, nil +} + +func resizeSlice(b []byte, size int) []byte { + if cap(b) < size { + return make([]byte, size) + } + return b[:size] +} + +// ReuseBuffer causes the decoder to reuse its buffer on subsequent decodes. +// The decoder may return messages that cannot handle allocations. +func (d *Decoder) ReuseBuffer() { + d.reuse = true } // Unmarshal reads an unpacked serialized stream into a message. No