Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi: Make NewMessage() usable for creating messages for reading #591

Merged
merged 5 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions answer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestPromiseFulfill(t *testing.T) {
t.Run("Done", func(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
done := p.Answer().Done()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -75,7 +75,7 @@ func TestPromiseFulfill(t *testing.T) {
p := NewPromise(dummyMethod, dummyPipelineCaller{}, nil)
defer p.ReleaseClients()
ans := p.Answer()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -99,7 +99,7 @@ func TestPromiseFulfill(t *testing.T) {
h := new(dummyHook)
c := NewClient(h)
defer c.Release()
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
defer msg.Release()

res, _ := NewStruct(seg, ObjectSize{PointerCount: 3})
Expand Down
20 changes: 19 additions & 1 deletion arena.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ func (ssa *SingleSegmentArena) Release() {
type MultiSegmentArena struct {
segs []Segment

// rawData is set when the individual segments were all demuxed from
// the passed raw data slice.
rawData []byte

// bp is the bufferpool assotiated with this arena's segments if it was
// initialized for writing.
bp *bufferpool.Pool
Expand All @@ -175,6 +179,7 @@ func MultiSegment(b [][]byte) *MultiSegmentArena {
if b == nil {
msa := multiSegmentPool.Get().(*MultiSegmentArena)
msa.fromPool = true
msa.bp = &bufferpool.Default
return msa
}
return multiSegment(b)
Expand All @@ -190,6 +195,14 @@ func MultiSegment(b [][]byte) *MultiSegmentArena {
// Calling Release is optional; if not done the garbage collector
// will release the memory per usual.
func (msa *MultiSegmentArena) Release() {
// When this was demuxed from a single slice, return the entire slice.
if msa.rawData != nil && msa.bp != nil {
zeroSlice(msa.rawData)
msa.bp.Put(msa.rawData)
msa.bp = nil
}
msa.rawData = nil

for i := range msa.segs {
if msa.bp != nil {
zeroSlice(msa.segs[i].data)
Expand Down Expand Up @@ -236,7 +249,10 @@ var multiSegmentPool = sync.Pool{

// demuxArena slices data into a multi-segment arena. It assumes that
// len(data) >= hdr.totalSize().
func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte) error {
//
// bp should point to the bufferpool which will receive back data once the
// arena is released. It may be nil if this should not be returned anywhere.
func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte, bp *bufferpool.Pool) error {
maxSeg := hdr.maxSegment()
if int64(maxSeg) > int64(maxInt-1) {
return errors.New("number of segments overflows int")
Expand All @@ -261,6 +277,8 @@ func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte) error {
msa.segs[i].id = i
}

msa.rawData = data
msa.bp = bp
return nil
}

Expand Down
9 changes: 8 additions & 1 deletion canonical.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ import (
// for equivalent structs, even as the schema evolves. The blob is
// suitable for hashing or signing.
func Canonicalize(s Struct) ([]byte, error) {
msg, seg, _ := NewMessage(SingleSegment(nil))
msg, seg := NewSingleSegmentMessage(nil)
if !s.IsValid() {
// Ensure compatbility to existing behavior: even if the struct
// is not valid, at least the root pointer is allocated and
// returned as canonical. Without this,
// TestCanonicalize/Struct{} fails.
if _, err := msg.allocRootPointerSpace(); err != nil {
return nil, err
}
return seg.Data(), nil
}
root, err := NewRootStruct(seg, canonicalStructSize(s))
Expand Down
20 changes: 10 additions & 10 deletions canonical_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "empty struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{})
return s
},
want: []byte{0xfc, 0xff, 0xff, 0xff, 0, 0, 0, 0},
}, {
name: "zero data, zero pointer struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{DataSize: 8, PointerCount: 1})
return s
},
want: []byte{0xfc, 0xff, 0xff, 0xff, 0, 0, 0, 0},
}, {
name: "one word data struct",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{DataSize: 8, PointerCount: 1})
s.SetUint16(0, 0xbeef)
return s
Expand All @@ -47,7 +47,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "two pointers to zero structs",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 2})
e1, _ := NewStruct(seg, ObjectSize{DataSize: 8})
e2, _ := NewStruct(seg, ObjectSize{DataSize: 8})
Expand All @@ -63,7 +63,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "pointer to interface",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 2})
iface := NewInterface(seg, 1)
s.SetPtr(0, iface.ToPtr())
Expand All @@ -76,7 +76,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "int list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewInt8List(seg, 5)
s.SetPtr(0, l.ToPtr())
Expand All @@ -95,7 +95,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero int list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewInt8List(seg, 5)
s.SetPtr(0, l.ToPtr())
Expand All @@ -110,7 +110,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 8, PointerCount: 1}, 2)
s.SetPtr(0, l.ToPtr())
Expand All @@ -133,7 +133,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 16, PointerCount: 2}, 3)
s.SetPtr(0, l.ToPtr())
Expand All @@ -148,7 +148,7 @@ func TestCanonicalize(t *testing.T) {
}, {
name: "zero-length struct list",
f: func() Struct {
_, seg, _ := NewMessage(SingleSegment(nil))
_, seg := NewSingleSegmentMessage(nil)
s, _ := NewStruct(seg, ObjectSize{PointerCount: 1})
l, _ := NewCompositeList(seg, ObjectSize{DataSize: 16, PointerCount: 2}, 0)
s.SetPtr(0, l.ToPtr())
Expand Down
32 changes: 9 additions & 23 deletions capability_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,8 @@ func (dr *dummyReturner) AllocResults(sz ObjectSize) (Struct, error) {
if dr.s.IsValid() {
return Struct{}, errors.New("AllocResults called multiple times")
}
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
return Struct{}, err
}
_, seg := NewSingleSegmentMessage(nil)
var err error
dr.s, err = NewRootStruct(seg, sz)
return dr.s, err
}
Expand All @@ -377,10 +375,7 @@ func (dr *dummyReturner) ReleaseResults() {
}

func TestToInterface(t *testing.T) {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, seg := NewSingleSegmentMessage(nil)
tests := []struct {
ptr Ptr
in Interface
Expand All @@ -399,10 +394,7 @@ func TestToInterface(t *testing.T) {
}

func TestInterface_value(t *testing.T) {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, seg := NewSingleSegmentMessage(nil)
tests := []struct {
in Interface
val rawPointer
Expand All @@ -421,10 +413,7 @@ func TestInterface_value(t *testing.T) {
}

func TestTransform(t *testing.T) {
_, s, err := NewMessage(SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
_, s := NewSingleSegmentMessage(nil)
root, err := NewStruct(s, ObjectSize{PointerCount: 2})
if err != nil {
t.Fatal(err)
Expand All @@ -442,7 +431,7 @@ func TestTransform(t *testing.T) {
b.SetUint64(0, 2)
a.SetPtr(0, b.ToPtr())

dmsg, d, err := NewMessage(SingleSegment(nil))
dmsg, d := NewSingleSegmentMessage(nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -675,20 +664,17 @@ func deepPointerEqual(a, b Ptr) bool {
if !a.IsValid() || !b.IsValid() {
return false
}
msgA, _, _ := NewMessage(SingleSegment(nil))
msgA, _ := NewSingleSegmentMessage(nil)
msgA.SetRoot(a)
abytes, _ := msgA.Marshal()
msgB, _, _ := NewMessage(SingleSegment(nil))
msgB, _ := NewSingleSegmentMessage(nil)
msgB.SetRoot(b)
bbytes, _ := msgB.Marshal()
return bytes.Equal(abytes, bbytes)
}

func newEmptyStruct() Struct {
_, seg, err := NewMessage(SingleSegment(nil))
if err != nil {
panic(err)
}
_, seg := NewSingleSegmentMessage(nil)
s, err := NewRootStruct(seg, ObjectSize{})
if err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion capnpc-go/capnpc-go.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (g *generator) defineSchemaVar() error {
}
sort.Sort(uint64Slice(ids))

msg, seg, _ := capnp.NewMessage(capnp.SingleSegment(nil))
msg, seg := capnp.NewSingleSegmentMessage(nil)
req, _ := schema.NewRootCodeGeneratorRequest(seg)
// TODO(light): find largest object size and use that to allocate list
nodes, _ := req.NewNodes(int32(len(g.nodes)))
Expand Down
7 changes: 2 additions & 5 deletions capnpc-go/fileparts.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@ func (sd *staticData) init(fileID uint64) {
}

func (sd *staticData) copyData(obj capnp.Ptr) (staticDataRef, error) {
m, _, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
return staticDataRef{}, err
}
err = m.SetRoot(obj)
m, _ := capnp.NewSingleSegmentMessage(nil)
err := m.SetRoot(obj)
if err != nil {
return staticDataRef{}, err
}
Expand Down
22 changes: 17 additions & 5 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,35 @@ func (d *Decoder) Decode() (*Message, error) {
if err != nil {
return nil, exc.WrapError("decode", err)
}

// Special case an empty message to return a new MultiSegment message
// ready for writing. This maintains compatibility to tests and older
// implementation of message and arenas.
if hdr.maxSegment() == 0 && total == 0 {
msg, _ := NewMultiSegmentMessage(nil)
return msg, nil
}

// TODO(someday): if total size is greater than can fit in one buffer,
// attempt to allocate buffer per segment.
if total > maxSize-uint64(len(hdr)) || total > uint64(maxInt) {
return nil, errors.New("decode: message too large")
}

// Read segments.
buf := bufferpool.Default.Get(int(total))
bp := &bufferpool.Default
buf := bp.Get(int(total))
if _, err := io.ReadFull(d.r, buf); err != nil {
return nil, exc.WrapError("decode: read segments", err)
}

arena := MultiSegment(nil)
if err = arena.demux(hdr, buf); err != nil {
if err = arena.demux(hdr, buf, bp); err != nil {
return nil, exc.WrapError("decode", err)
}

return &Message{Arena: arena}, nil
msg, _, err := NewMessage(arena)
return msg, err
}

func (d *Decoder) readHeader(maxSize uint64) (streamHeader, error) {
Expand Down Expand Up @@ -162,11 +173,12 @@ func Unmarshal(data []byte) (*Message, error) {
}

arena := MultiSegment(nil)
if err := arena.demux(hdr, data); err != nil {
if err := arena.demux(hdr, data, nil); err != nil {
return nil, exc.WrapError("unmarshal", err)
}

return &Message{Arena: arena}, nil
msg, _, err := NewMessage(arena)
return msg, err
}

// UnmarshalPacked reads a packed serialized stream into a message.
Expand Down
Loading