diff --git a/gen.go b/gen.go index c572c0d..8cb50c7 100644 --- a/gen.go +++ b/gen.go @@ -1085,9 +1085,8 @@ func emitCborMarshalStructMap(w io.Writer, gti *GenTypeInfo) error { if _, err := w.Write({{ .MapHeaderAsByteString }}); err != nil { return err } -{{ if .NeedsScratch }} + scratch := make([]byte, 9) -{{ end }} `) if err != nil { return err diff --git a/testgen/main.go b/testgen/main.go index c2d7a2e..3d03ab3 100644 --- a/testgen/main.go +++ b/testgen/main.go @@ -18,6 +18,7 @@ func main() { if err := cbg.WriteMapEncodersToFile("testing/cbor_map_gen.go", "testing", types.SimpleTypeTree{}, + types.NeedScratchForMap{}, ); err != nil { panic(err) } diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index 1016c5d..1009d6f 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -408,3 +408,91 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { return nil } +func (t *NeedScratchForMap) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + if _, err := w.Write([]byte{161}); err != nil { + return err + } + + scratch := make([]byte, 9) + + // t.Thing (bool) (bool) + if len("Thing") > cbg.MaxLength { + return xerrors.Errorf("Value in field \"Thing\" was too long") + } + + if err := cbg.WriteMajorTypeHeaderBuf(scratch, w, cbg.MajTextString, uint64(len("Thing"))); err != nil { + return err + } + if _, err := io.WriteString(w, string("Thing")); err != nil { + return err + } + + if err := cbg.WriteBool(w, t.Thing); err != nil { + return err + } + return nil +} + +func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) error { + *t = NeedScratchForMap{} + + br := cbg.GetPeeker(r) + scratch := make([]byte, 8) + + maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("NeedScratchForMap: map struct too large (%d)", extra) + } + + var name string + n := extra + + for i := uint64(0); i < n; i++ { + + { + sval, err := cbg.ReadStringBuf(br, scratch) + if err != nil { + return err + } + + name = string(sval) + } + + switch name { + // t.Thing (bool) (bool) + case "Thing": + + maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) + if err != nil { + return err + } + if maj != cbg.MajOther { + return fmt.Errorf("booleans must be major type 7") + } + switch extra { + case 20: + t.Thing = false + case 21: + t.Thing = true + default: + return fmt.Errorf("booleans are either major type 7, value 20 or 21 (got %d)", extra) + } + + default: + return fmt.Errorf("unknown struct field %d: '%s'", i, name) + } + } + + return nil +} diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index b506ff5..642e539 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -37,6 +37,10 @@ func TestSimpleTypeTree(t *testing.T) { testTypeRoundtrips(t, reflect.TypeOf(SimpleTypeTree{})) } +func TestNeedScratchForMap(t *testing.T) { + testTypeRoundtrips(t, reflect.TypeOf(NeedScratchForMap{})) +} + func testValueRoundtrip(t *testing.T, obj cbg.CBORMarshaler, nobj cbg.CBORUnmarshaler) { buf := new(bytes.Buffer) diff --git a/testing/types.go b/testing/types.go index aa3a993..623ac60 100644 --- a/testing/types.go +++ b/testing/types.go @@ -54,3 +54,8 @@ type FixedArrays struct { Uint8 [20]uint8 Uint64 [20]uint64 } + +// Do not add fields to this type. +type NeedScratchForMap struct { + Thing bool +}