Skip to content
This repository has been archived by the owner on Feb 1, 2023. It is now read-only.

Commit

Permalink
feat: add a custom CID type
Browse files Browse the repository at this point in the history
This allows us to marshal/unmarshal/size protobufs without copying CID around.
  • Loading branch information
Stebalien committed Mar 18, 2020
1 parent f6db5f7 commit 35db51b
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 135 deletions.
26 changes: 12 additions & 14 deletions message/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package message

import (
"encoding/binary"
"fmt"
"errors"
"io"

pb "github.com/ipfs/go-bitswap/message/pb"
Expand Down Expand Up @@ -117,14 +117,15 @@ type Entry struct {
SendDontHave bool
}

var errCidMissing = errors.New("missing cid")

func newMessageFromProto(pbm pb.Message) (BitSwapMessage, error) {
m := newMsg(pbm.Wantlist.Full)
for _, e := range pbm.Wantlist.Entries {
c, err := cid.Cast([]byte(e.Block))
if err != nil {
return nil, fmt.Errorf("incorrectly formatted cid in wantlist: %s", err)
if !e.Block.Cid.Defined() {
return nil, errCidMissing
}
m.addEntry(c, e.Priority, e.Cancel, e.WantType, e.SendDontHave)
m.addEntry(e.Block.Cid, e.Priority, e.Cancel, e.WantType, e.SendDontHave)
}

// deprecated
Expand Down Expand Up @@ -155,13 +156,10 @@ func newMessageFromProto(pbm pb.Message) (BitSwapMessage, error) {
}

for _, bi := range pbm.GetBlockPresences() {
c, err := cid.Cast(bi.GetCid())
if err != nil {
return nil, err
if !bi.Cid.Cid.Defined() {
return nil, errCidMissing
}

t := bi.GetType()
m.AddBlockPresence(c, t)
m.AddBlockPresence(bi.Cid.Cid, bi.Type)
}

m.pendingBytes = pbm.PendingBytes
Expand Down Expand Up @@ -311,7 +309,7 @@ func (m *impl) Size() int {

func BlockPresenceSize(c cid.Cid) int {
return (&pb.Message_BlockPresence{
Cid: c.Bytes(),
Cid: pb.Cid{Cid: c},
Type: pb.Message_Have,
}).Size()
}
Expand Down Expand Up @@ -341,7 +339,7 @@ func FromMsgReader(r msgio.Reader) (BitSwapMessage, error) {

func entryToPB(e *Entry) pb.Message_Wantlist_Entry {
return pb.Message_Wantlist_Entry{
Block: e.Cid.Bytes(),
Block: pb.Cid{Cid: e.Cid},
Priority: int32(e.Priority),
Cancel: e.Cancel,
WantType: e.WantType,
Expand Down Expand Up @@ -385,7 +383,7 @@ func (m *impl) ToProtoV1() *pb.Message {
pbm.BlockPresences = make([]pb.Message_BlockPresence, 0, len(m.blockPresences))
for c, t := range m.blockPresences {
pbm.BlockPresences = append(pbm.BlockPresences, pb.Message_BlockPresence{
Cid: c.Bytes(),
Cid: pb.Cid{Cid: c},
Type: t,
})
}
Expand Down
4 changes: 2 additions & 2 deletions message/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestNewMessageFromProto(t *testing.T) {
str := mkFakeCid("a_key")
protoMessage := new(pb.Message)
protoMessage.Wantlist.Entries = []pb.Message_Wantlist_Entry{
{Block: str.Bytes()},
{Block: pb.Cid{Cid: str}},
}
if !wantlistContains(&protoMessage.Wantlist, str) {
t.Fail()
Expand Down Expand Up @@ -164,7 +164,7 @@ func TestToAndFromNetMessage(t *testing.T) {

func wantlistContains(wantlist *pb.Message_Wantlist, c cid.Cid) bool {
for _, e := range wantlist.GetEntries() {
if bytes.Equal(e.GetBlock(), c.Bytes()) {
if e.Block.Cid.Defined() && c.Equals(e.Block.Cid) {
return true
}
}
Expand Down
43 changes: 43 additions & 0 deletions message/pb/cid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package bitswap_message_pb

import (
"github.com/ipfs/go-cid"
)

// NOTE: Don't "embed" the cid, wrap it like we're doing here. Otherwise, gogo
// will try to use the Bytes() function.

// Cid is a custom type for CIDs in protobufs, that allows us to avoid
// reallocating.
type Cid struct {
Cid cid.Cid
}

func (c Cid) Marshal() ([]byte, error) {
return c.Cid.Bytes(), nil
}

func (c *Cid) MarshalTo(data []byte) (int, error) {
return copy(data[:c.Size()], c.Cid.Bytes()), nil
}

func (c *Cid) Unmarshal(data []byte) (err error) {
c.Cid, err = cid.Cast(data)
return err
}

func (c *Cid) Size() int {
return len(c.Cid.KeyString())
}

func (c Cid) MarshalJSON() ([]byte, error) {
return c.Cid.MarshalJSON()
}

func (c *Cid) UnmarshalJSON(data []byte) error {
return c.Cid.UnmarshalJSON(data)
}

func (c Cid) Equal(other Cid) bool {
return c.Cid.Equals(c.Cid)
}
Loading

0 comments on commit 35db51b

Please sign in to comment.