From fd5c031877d3554979ebf472ef2fffa2e7ef19f7 Mon Sep 17 00:00:00 2001 From: Dirk McCormick Date: Tue, 16 Feb 2021 15:17:43 +0100 Subject: [PATCH] feat: allow unmarshalling of object with same fields + more fields than marshalled object --- gen.go | 4 +- package.go | 1 + testgen/main.go | 2 + testing/cbor_gen.go | 2 + testing/cbor_map_gen.go | 588 +++++++++++++++++++++++++++++++++++++- testing/roundtrip_test.go | 93 ++++++ testing/types.go | 22 ++ 7 files changed, 709 insertions(+), 3 deletions(-) diff --git a/gen.go b/gen.go index ed2bf24..7399f8b 100644 --- a/gen.go +++ b/gen.go @@ -64,6 +64,7 @@ import ( var _ = xerrors.Errorf +var _ = cid.Undef `) } @@ -1269,7 +1270,8 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { return doTemplate(w, gti, ` default: - return fmt.Errorf("unknown struct field %d: '%s'", i, name) + // Field doesn't exist on this type, so ignore it + cbg.ScanForLinks(r, func(cid.Cid){}) } } diff --git a/package.go b/package.go index 1943e09..7fa26f7 100644 --- a/package.go +++ b/package.go @@ -16,6 +16,7 @@ var ( defaultImports = []Import{ {Name: "cbg", PkgPath: "github.com/whyrusleeping/cbor-gen"}, {Name: "xerrors", PkgPath: "golang.org/x/xerrors"}, + {Name: "cid", PkgPath: "github.com/ipfs/go-cid"}, } ) diff --git a/testgen/main.go b/testgen/main.go index acf5be4..c94aee7 100644 --- a/testgen/main.go +++ b/testgen/main.go @@ -20,6 +20,8 @@ func main() { if err := cbg.WriteMapEncodersToFile("testing/cbor_map_gen.go", "testing", types.SimpleTypeTree{}, types.NeedScratchForMap{}, + types.SimpleStructV1{}, + types.SimpleStructV2{}, ); err != nil { panic(err) } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index b3e2e14..6d349d5 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -6,11 +6,13 @@ import ( "fmt" "io" + cid "github.com/ipfs/go-cid" cbg "github.com/whyrusleeping/cbor-gen" xerrors "golang.org/x/xerrors" ) var _ = xerrors.Errorf +var _ = cid.Undef var lengthBufSignedArray = []byte{129} diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index 1009d6f..3c2071a 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -6,11 +6,13 @@ import ( "fmt" "io" + cid "github.com/ipfs/go-cid" cbg "github.com/whyrusleeping/cbor-gen" xerrors "golang.org/x/xerrors" ) var _ = xerrors.Errorf +var _ = cid.Undef func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { if t == nil { @@ -402,7 +404,8 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { } default: - return fmt.Errorf("unknown struct field %d: '%s'", i, name) + // Field doesn't exist on this type, so ignore it + cbg.ScanForLinks(r, func(cid.Cid) {}) } } @@ -490,7 +493,588 @@ func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) error { } default: - return fmt.Errorf("unknown struct field %d: '%s'", i, name) + // Field doesn't exist on this type, so ignore it + cbg.ScanForLinks(r, func(cid.Cid) {}) + } + } + + return nil +} +func (t *SimpleStructV1) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{164}); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.OldStr (string) (string) + if len("OldStr") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldStr\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldStr"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldStr")); err != nil { + return err + } + + if len(t.OldStr) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.OldStr was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.OldStr))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.OldStr)); err != nil { + return err + } + + // t.OldBytes ([]uint8) (slice) + if len("OldBytes") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldBytes\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldBytes"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldBytes")); err != nil { + return err + } + + if len(t.OldBytes) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.OldBytes was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.OldBytes))); err != nil { + return err + } + + if _, err := w.Write(t.OldBytes[:]); err != nil { + return err + } + + // t.OldNum (uint64) (uint64) + if len("OldNum") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldNum\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldNum"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldNum")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.OldNum)); err != nil { + return err + } + + // t.OldPtr (cid.Cid) (struct) + if len("OldPtr") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldPtr\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldPtr"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldPtr")); err != nil { + return err + } + + if t.OldPtr == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.OldPtr); err != nil { + return xerrors.Errorf("failed to write cid field t.OldPtr: %w", err) + } + } + + return nil +} + +func (t *SimpleStructV1) UnmarshalCBOR(r io.Reader) error { + *t = SimpleStructV1{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("SimpleStructV1: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.OldStr (string) (string) + case "OldStr": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.OldStr = string(sval) + } + // t.OldBytes ([]uint8) (slice) + case "OldBytes": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.ByteArrayMaxLen { + return fmt.Errorf("t.OldBytes: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra > 0 { + t.OldBytes = make([]uint8, extra) + } + + if _, err := io.ReadFull(br, t.OldBytes[:]); err != nil { + return err + } + // t.OldNum (uint64) (uint64) + case "OldNum": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.OldNum = uint64(extra) + + } + // t.OldPtr (cid.Cid) (struct) + case "OldPtr": + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.OldPtr: %w", err) + } + + t.OldPtr = &c + } + + } + + default: + // Field doesn't exist on this type, so ignore it + cbg.ScanForLinks(r, func(cid.Cid) {}) + } + } + + return nil +} +func (t *SimpleStructV2) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{168}); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.OldStr (string) (string) + if len("OldStr") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldStr\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldStr"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldStr")); err != nil { + return err + } + + if len(t.OldStr) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.OldStr was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.OldStr))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.OldStr)); err != nil { + return err + } + + // t.NewStr (string) (string) + if len("NewStr") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"NewStr\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewStr"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("NewStr")); err != nil { + return err + } + + if len(t.NewStr) > cbg.MaxLength { + return xerrors.Errorf("Value in field t.NewStr was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len(t.NewStr))); err != nil { + return err + } + if _, err := io.WriteString(w, string(t.NewStr)); err != nil { + return err + } + + // t.OldBytes ([]uint8) (slice) + if len("OldBytes") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldBytes\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldBytes"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldBytes")); err != nil { + return err + } + + if len(t.OldBytes) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.OldBytes was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.OldBytes))); err != nil { + return err + } + + if _, err := w.Write(t.OldBytes[:]); err != nil { + return err + } + + // t.NewBytes ([]uint8) (slice) + if len("NewBytes") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"NewBytes\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewBytes"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("NewBytes")); err != nil { + return err + } + + if len(t.NewBytes) > cbg.ByteArrayMaxLen { + return xerrors.Errorf("Byte array in field t.NewBytes was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajByteString, uint64(len(t.NewBytes))); err != nil { + return err + } + + if _, err := w.Write(t.NewBytes[:]); err != nil { + return err + } + + // t.OldNum (uint64) (uint64) + if len("OldNum") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldNum\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldNum"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldNum")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.OldNum)); err != nil { + return err + } + + // t.NewNum (uint64) (uint64) + if len("NewNum") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"NewNum\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewNum"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("NewNum")); err != nil { + return err + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajUnsignedInt, uint64(t.NewNum)); err != nil { + return err + } + + // t.OldPtr (cid.Cid) (struct) + if len("OldPtr") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"OldPtr\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("OldPtr"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("OldPtr")); err != nil { + return err + } + + if t.OldPtr == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.OldPtr); err != nil { + return xerrors.Errorf("failed to write cid field t.OldPtr: %w", err) + } + } + + // t.NewPtr (cid.Cid) (struct) + if len("NewPtr") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"NewPtr\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("NewPtr"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("NewPtr")); err != nil { + return err + } + + if t.NewPtr == nil { + if _, err := w.Write(cbg.CborNull); err != nil { + return err + } + } else { + if err := cbg.WriteCidBuf(scratch, w, *t.NewPtr); err != nil { + return xerrors.Errorf("failed to write cid field t.NewPtr: %w", err) + } + } + + return nil +} + +func (t *SimpleStructV2) UnmarshalCBOR(r io.Reader) error { + *t = SimpleStructV2{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("SimpleStructV2: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.OldStr (string) (string) + case "OldStr": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.OldStr = string(sval) + } + // t.NewStr (string) (string) + case "NewStr": + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + t.NewStr = string(sval) + } + // t.OldBytes ([]uint8) (slice) + case "OldBytes": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.ByteArrayMaxLen { + return fmt.Errorf("t.OldBytes: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra > 0 { + t.OldBytes = make([]uint8, extra) + } + + if _, err := io.ReadFull(br, t.OldBytes[:]); err != nil { + return err + } + // t.NewBytes ([]uint8) (slice) + case "NewBytes": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + + if extra > cbg.ByteArrayMaxLen { + return fmt.Errorf("t.NewBytes: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra > 0 { + t.NewBytes = make([]uint8, extra) + } + + if _, err := io.ReadFull(br, t.NewBytes[:]); err != nil { + return err + } + // t.OldNum (uint64) (uint64) + case "OldNum": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.OldNum = uint64(extra) + + } + // t.NewNum (uint64) (uint64) + case "NewNum": + + { + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.NewNum = uint64(extra) + + } + // t.OldPtr (cid.Cid) (struct) + case "OldPtr": + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.OldPtr: %w", err) + } + + t.OldPtr = &c + } + + } + // t.NewPtr (cid.Cid) (struct) + case "NewPtr": + + { + + b, err := br.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { + return err + } + + c, err := cbg.ReadCid(br) + if err != nil { + return xerrors.Errorf("failed to read cid field t.NewPtr: %w", err) + } + + t.NewPtr = &c + } + + } + + default: + // Field doesn't exist on this type, so ignore it + cbg.ScanForLinks(r, func(cid.Cid) {}) } } diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index fed37c6..886887e 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -3,6 +3,7 @@ package testing import ( "bytes" "encoding/json" + "github.com/ipfs/go-cid" "math/rand" "reflect" "testing" @@ -162,3 +163,95 @@ func TestTimeIsh(t *testing.T) { } } + +func TestLessToMoreFieldsRoundTrip(t *testing.T) { + dummyCid, _ := cid.Parse("bafkqaaa") + obj := &SimpleStructV1{ + OldStr: "hello", + OldBytes: []byte("bytes"), + OldNum: 10, + OldPtr: &dummyCid, + } + + buf := new(bytes.Buffer) + if err := obj.MarshalCBOR(buf); err != nil { + t.Fatal("failed marshaling", err) + } + + enc := buf.Bytes() + + nobj := SimpleStructV2{} + if err := nobj.UnmarshalCBOR(bytes.NewReader(enc)); err != nil { + t.Logf("got bad bytes: %x", enc) + t.Fatal("failed to round trip object: ", err) + } + + if obj.OldStr != nobj.OldStr { + t.Fatal("mismatch ", obj.OldStr, " != ", nobj.OldStr) + } + if nobj.NewStr != "" { + t.Fatal("expected field to be zero value") + } + + if obj.OldNum != nobj.OldNum { + t.Fatal("mismatch ", obj.OldNum, " != ", nobj.OldNum) + } + if nobj.NewNum != 0 { + t.Fatal("expected field to be zero value") + } + + if !bytes.Equal(obj.OldBytes, nobj.OldBytes) { + t.Fatal("mismatch ", obj.OldBytes, " != ", nobj.OldBytes) + } + if nobj.NewBytes != nil { + t.Fatal("expected field to be zero value") + } + + if *obj.OldPtr != *nobj.OldPtr { + t.Fatal("mismatch ", obj.OldPtr, " != ", nobj.OldPtr) + } + if nobj.NewPtr != nil { + t.Fatal("expected field to be zero value") + } +} + +func TestMoreToLessFieldsRoundTrip(t *testing.T) { + dummyCid1, _ := cid.Parse("bafkqaaa") + dummyCid2, _ := cid.Parse("bafkqaab") + obj := &SimpleStructV2{ + OldStr: "oldstr", + NewStr: "newstr", + OldBytes: []byte("oldbytes"), + NewBytes: []byte("newbytes"), + OldNum: 10, + NewNum: 11, + OldPtr: &dummyCid1, + NewPtr: &dummyCid2, + } + + buf := new(bytes.Buffer) + if err := obj.MarshalCBOR(buf); err != nil { + t.Fatal("failed marshaling", err) + } + + enc := buf.Bytes() + + nobj := SimpleStructV1{} + if err := nobj.UnmarshalCBOR(bytes.NewReader(enc)); err != nil { + t.Logf("got bad bytes: %x", enc) + t.Fatal("failed to round trip object: ", err) + } + + if obj.OldStr != nobj.OldStr { + t.Fatal("mismatch", obj.OldStr, " != ", nobj.OldStr) + } + if obj.OldNum != nobj.OldNum { + t.Fatal("mismatch ", obj.OldNum, " != ", nobj.OldNum) + } + if !bytes.Equal(obj.OldBytes, nobj.OldBytes) { + t.Fatal("mismatch ", obj.OldBytes, " != ", nobj.OldBytes) + } + if *obj.OldPtr != *nobj.OldPtr { + t.Fatal("mismatch ", obj.OldPtr, " != ", nobj.OldPtr) + } +} diff --git a/testing/types.go b/testing/types.go index 58a8d2c..239baa9 100644 --- a/testing/types.go +++ b/testing/types.go @@ -1,6 +1,7 @@ package testing import ( + "github.com/ipfs/go-cid" cbg "github.com/whyrusleeping/cbor-gen" ) @@ -43,6 +44,27 @@ type SimpleTypeTree struct { NotPizza *uint64 } +type SimpleStructV1 struct { + OldStr string + OldBytes []byte + OldNum uint64 + OldPtr *cid.Cid +} + +type SimpleStructV2 struct { + OldStr string + NewStr string + + OldBytes []byte + NewBytes []byte + + OldNum uint64 + NewNum uint64 + + OldPtr *cid.Cid + NewPtr *cid.Cid +} + type DeferredContainer struct { Stuff *SimpleTypeOne Deferred *cbg.Deferred