From 971bafc586d22c5bb5dcaab03a1d3e5890e3308b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Magiera?= Date: Wed, 24 Jan 2024 00:28:10 +0100 Subject: [PATCH 1/4] feat: Support transparent struct fields --- gen.go | 42 ++++- testgen/main.go | 3 + testing/cbor_gen.go | 338 ++++++++++++++++++++++++++++++++++++++ testing/roundtrip_test.go | 44 +++++ testing/types.go | 15 ++ 5 files changed, 433 insertions(+), 9 deletions(-) diff --git a/gen.go b/gen.go index 1155c04..f0eb2e6 100644 --- a/gen.go +++ b/gen.go @@ -145,8 +145,9 @@ func (f Field) Len() int { } type GenTypeInfo struct { - Name string - Fields []Field + Name string + Fields []Field + Transparent bool } func (gti *GenTypeInfo) Imports() []Import { @@ -178,10 +179,15 @@ func ParseTypeInfo(itype interface{}) (*GenTypeInfo, error) { pkg := t.PkgPath() out := GenTypeInfo{ - Name: t.Name(), + Name: t.Name(), + Transparent: false, } for i := 0; i < t.NumField(); i++ { + if out.Transparent { + return nil, fmt.Errorf("transparent structs must exactly one field") + } + f := t.Field(i) if !nameIsExported(f.Name) { continue @@ -226,6 +232,12 @@ func ParseTypeInfo(itype interface{}) (*GenTypeInfo, error) { constval = &cv } + _, transparent := tags["transparent"] + if transparent && len(out.Fields) > 0 { + return nil, fmt.Errorf("only one transparent field is allowed") + } + out.Transparent = transparent + _, omitempty := tags["omitempty"] _, preservenil := tags["preservenil"] @@ -270,6 +282,8 @@ func tagparse(v string) (map[string]string, error) { out["preservenil"] = "true" } else if elem == "ignore" || elem == "-" { out["ignore"] = "true" + } else if elem == "transparent" { + out["transparent"] = "true" } else { out["name"] = elem } @@ -680,18 +694,19 @@ func emitCborMarshalArrayField(w io.Writer, f Field) error { func emitCborMarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { // 9 byte buffer to accomodate for the maximum header length (cbor varints are maximum 9 bytes_ - err := doTemplate(w, gti, `var lengthBuf{{ .Name }} = {{ .TupleHeaderAsByteString }} + err := doTemplate(w, gti, `{{if not .Transparent}}var lengthBuf{{ .Name }} = {{ .TupleHeaderAsByteString }}{{end}} func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { - if t == nil { + {{if not .Transparent}}if t == nil { _, err := w.Write(cbg.CborNull) return err - } + }{{end}} cw := cbg.NewCborWriter(w) - + {{if not .Transparent}} if _, err := cw.Write(lengthBuf{{ .Name }}); err != nil { return err } + {{end}} `) if err != nil { return err @@ -1428,7 +1443,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { *t = {{.Name}}{} cr := cbg.NewCborReader(r) - + {{ if not .Transparent }} maj, extra, err := {{ ReadHeader "cr" }} if err != nil { return err @@ -1446,7 +1461,12 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { if extra != {{ len .Fields }} { return fmt.Errorf("cbor input had wrong number of fields") } - + {{ else }} + var maj byte + var extra uint64 + _ = maj + _ = extra + {{ end }} `) if err != nil { return err @@ -1539,6 +1559,10 @@ func emitCborMarshalStructMap(w io.Writer, gti *GenTypeInfo) error { } } + if gti.Transparent { + return fmt.Errorf("transparent fields not supported in map mode, use tuple encoding (outcome should be the same)") + } + err := doTemplate(w, gti, `func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) diff --git a/testgen/main.go b/testgen/main.go index 214f268..abb1cc2 100644 --- a/testgen/main.go +++ b/testgen/main.go @@ -14,6 +14,9 @@ func main() { types.FixedArrays{}, types.ThingWithSomeTime{}, types.BigField{}, + types.IntArray{}, + types.IntAliasArray{}, + types.TupleIntArray{}, ); err != nil { panic(err) } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index e25a5d2..97e50f2 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -1297,3 +1297,341 @@ func (t *BigField) UnmarshalCBOR(r io.Reader) (err error) { return nil } + +func (t *IntArray) MarshalCBOR(w io.Writer) error { + + cw := cbg.NewCborWriter(w) + + // t.Ints ([]int64) (slice) + if len(t.Ints) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Ints was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len(t.Ints))); err != nil { + return err + } + for _, v := range t.Ints { + if v >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(v)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-v-1)); err != nil { + return err + } + } + + } + return nil +} + +func (t *IntArray) UnmarshalCBOR(r io.Reader) (err error) { + *t = IntArray{} + + cr := cbg.NewCborReader(r) + + var maj byte + var extra uint64 + _ = maj + _ = extra + + // t.Ints ([]int64) (slice) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("t.Ints: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + t.Ints = make([]int64, extra) + } + + for i := 0; i < int(extra); i++ { + { + var maj byte + var extra uint64 + var err error + _ = maj + _ = extra + _ = err + { + maj, extra, err := cr.ReadHeader() + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Ints[i] = int64(extraI) + } + + } + } + return nil +} + +func (t *IntAliasArray) MarshalCBOR(w io.Writer) error { + + cw := cbg.NewCborWriter(w) + + // t.Ints ([]testing.IntAlias) (slice) + if len(t.Ints) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field t.Ints was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len(t.Ints))); err != nil { + return err + } + for _, v := range t.Ints { + if v >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(v)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-v-1)); err != nil { + return err + } + } + + } + return nil +} + +func (t *IntAliasArray) UnmarshalCBOR(r io.Reader) (err error) { + *t = IntAliasArray{} + + cr := cbg.NewCborReader(r) + + var maj byte + var extra uint64 + _ = maj + _ = extra + + // t.Ints ([]testing.IntAlias) (slice) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("t.Ints: array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + t.Ints = make([]IntAlias, extra) + } + + for i := 0; i < int(extra); i++ { + { + var maj byte + var extra uint64 + var err error + _ = maj + _ = extra + _ = err + { + maj, extra, err := cr.ReadHeader() + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Ints[i] = IntAlias(extraI) + } + + } + } + return nil +} + +var lengthBufTupleIntArray = []byte{131} + +func (t *TupleIntArray) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + + if _, err := cw.Write(lengthBufTupleIntArray); err != nil { + return err + } + + // t.Int1 (int64) (int64) + if t.Int1 >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.Int1)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.Int1-1)); err != nil { + return err + } + } + + // t.Int2 (int64) (int64) + if t.Int2 >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.Int2)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.Int2-1)); err != nil { + return err + } + } + + // t.Int3 (int64) (int64) + if t.Int3 >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.Int3)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.Int3-1)); err != nil { + return err + } + } + return nil +} + +func (t *TupleIntArray) UnmarshalCBOR(r io.Reader) (err error) { + *t = TupleIntArray{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 3 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.Int1 (int64) (int64) + { + maj, extra, err := cr.ReadHeader() + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Int1 = int64(extraI) + } + // t.Int2 (int64) (int64) + { + maj, extra, err := cr.ReadHeader() + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Int2 = int64(extraI) + } + // t.Int3 (int64) (int64) + { + maj, extra, err := cr.ReadHeader() + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Int3 = int64(extraI) + } + return nil +} diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index 81e99dd..af7402b 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -451,3 +451,47 @@ func TestMapOfStringToString(t *testing.T) { } //TODO same for strings + +func TestTransparentIntArray(t *testing.T) { + t.Run("roundtrip", func(t *testing.T) { + zero := &IntArray{} + recepticle := &IntArray{} + testValueRoundtrip(t, zero, recepticle) + }) + + t.Run("roundtrip intalias", func(t *testing.T) { + zero := &IntAliasArray{} + recepticle := &IntAliasArray{} + testValueRoundtrip(t, zero, recepticle) + }) + + // non-zero values + t.Run("roundtrip non-zero", func(t *testing.T) { + val := &IntArray{Ints: []int64{1, 2, 3}} + recepticle := &IntArray{} + testValueRoundtrip(t, val, recepticle) + }) + t.Run("roundtrip non-zero intalias", func(t *testing.T) { + val := &IntAliasArray{Ints: []IntAlias{1, 2, 3}} + recepticle := &IntAliasArray{} + testValueRoundtrip(t, val, recepticle) + }) + + // tuple struct to/from transparent int array + t.Run("roundtrip tuple struct to transparent", func(t *testing.T) { + val := &TupleIntArray{2, 4, 5} + recepticle := &IntArray{} + testValueRoundtrip(t, val, recepticle) + if val.Int1 != recepticle.Ints[0] { + t.Fatal("mismatch") + } + }) + t.Run("roundtrip transparent to tuple struct", func(t *testing.T) { + val := &IntArray{Ints: []int64{2, 4, 5}} + recepticle := &TupleIntArray{} + testValueRoundtrip(t, val, recepticle) + if val.Ints[0] != recepticle.Int1 { + t.Fatal("mismatch") + } + }) +} diff --git a/testing/types.go b/testing/types.go index 46b9518..b9299f5 100644 --- a/testing/types.go +++ b/testing/types.go @@ -145,3 +145,18 @@ type TestSliceNilPreserve struct { NotOther []byte `cborgen:"preservenil"` Beep int64 } + +type IntAlias int64 + +type IntArray struct { + Ints []int64 `cborgen:"transparent"` +} +type IntAliasArray struct { + Ints []IntAlias `cborgen:"transparent"` +} + +type TupleIntArray struct { + Int1 int64 + Int2 int64 + Int3 int64 +} From 0145bb9c7c22b796e14ae33e376c2da9577ee3c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Magiera?= Date: Wed, 24 Jan 2024 01:16:41 +0100 Subject: [PATCH 2/4] Transparent newtypes from non-struct newtypes --- gen.go | 37 ++++- testgen/main.go | 3 + testing/cbor_gen.go | 292 ++++++++++++++++++++++++++++++++++++++ testing/roundtrip_test.go | 70 +++++++++ testing/types.go | 6 + 5 files changed, 404 insertions(+), 4 deletions(-) diff --git a/gen.go b/gen.go index f0eb2e6..926af67 100644 --- a/gen.go +++ b/gen.go @@ -183,6 +183,27 @@ func ParseTypeInfo(itype interface{}) (*GenTypeInfo, error) { Transparent: false, } + if t.Kind() != reflect.Struct { + return &GenTypeInfo{ + Name: t.Name(), + Transparent: true, + Fields: []Field{ + { + Name: ".", + MapKey: "", + Pointer: t.Kind() == reflect.Ptr, + Type: t, + Pkg: pkg, + Const: nil, + OmitEmpty: false, + PreserveNil: false, + IterLabel: "", + MaxLen: NoUsrMaxLen, + }, + }, + }, nil + } + for i := 0; i < t.NumField(); i++ { if out.Transparent { return nil, fmt.Errorf("transparent structs must exactly one field") @@ -713,8 +734,12 @@ func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { } for _, f := range gti.Fields { - fmt.Fprintf(w, "\n\t// t.%s (%s) (%s)", f.Name, f.Type, f.Type.Kind()) - f.Name = "t." + f.Name + if f.Name == "." { + f.Name = "(*t)" + } else { + f.Name = "t." + f.Name + } + fmt.Fprintf(w, "\n\t// %s (%s) (%s)", f.Name, f.Type, f.Type.Kind()) switch f.Type.Kind() { case reflect.String: @@ -1473,8 +1498,12 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { } for _, f := range gti.Fields { - fmt.Fprintf(w, "\t// t.%s (%s) (%s)\n", f.Name, f.Type, f.Type.Kind()) - f.Name = "t." + f.Name + if f.Name == "." { + f.Name = "(*t)" // self + } else { + f.Name = "t." + f.Name + } + fmt.Fprintf(w, "\t// %s (%s) (%s)\n", f.Name, f.Type, f.Type.Kind()) switch f.Type.Kind() { case reflect.String: diff --git a/testgen/main.go b/testgen/main.go index abb1cc2..e714467 100644 --- a/testgen/main.go +++ b/testgen/main.go @@ -17,6 +17,9 @@ func main() { types.IntArray{}, types.IntAliasArray{}, types.TupleIntArray{}, + types.IntArrayNewType{}, + types.IntArrayAliasNewType{}, + types.MapTransparentType{}, ); err != nil { panic(err) } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 97e50f2..62d8483 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -1635,3 +1635,295 @@ func (t *TupleIntArray) UnmarshalCBOR(r io.Reader) (err error) { } return nil } + +func (t *IntArrayNewType) MarshalCBOR(w io.Writer) error { + + cw := cbg.NewCborWriter(w) + + // (*t) (testing.IntArrayNewType) (slice) + if len((*t)) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field (*t) was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len((*t)))); err != nil { + return err + } + for _, v := range *t { + if v >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(v)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-v-1)); err != nil { + return err + } + } + + } + return nil +} + +func (t *IntArrayNewType) UnmarshalCBOR(r io.Reader) (err error) { + *t = IntArrayNewType{} + + cr := cbg.NewCborReader(r) + + var maj byte + var extra uint64 + _ = maj + _ = extra + + // (*t) (testing.IntArrayNewType) (slice) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("(*t): array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + (*t) = make([]int64, extra) + } + + for i := 0; i < int(extra); i++ { + { + var maj byte + var extra uint64 + var err error + _ = maj + _ = extra + _ = err + { + maj, extra, err := cr.ReadHeader() + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + (*t)[i] = int64(extraI) + } + + } + } + return nil +} + +func (t *IntArrayAliasNewType) MarshalCBOR(w io.Writer) error { + + cw := cbg.NewCborWriter(w) + + // (*t) (testing.IntArrayAliasNewType) (slice) + if len((*t)) > cbg.MaxLength { + return xerrors.Errorf("Slice value in field (*t) was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len((*t)))); err != nil { + return err + } + for _, v := range *t { + if v >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(v)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-v-1)); err != nil { + return err + } + } + + } + return nil +} + +func (t *IntArrayAliasNewType) UnmarshalCBOR(r io.Reader) (err error) { + *t = IntArrayAliasNewType{} + + cr := cbg.NewCborReader(r) + + var maj byte + var extra uint64 + _ = maj + _ = extra + + // (*t) (testing.IntArrayAliasNewType) (slice) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > cbg.MaxLength { + return fmt.Errorf("(*t): array too large (%d)", extra) + } + + if maj != cbg.MajArray { + return fmt.Errorf("expected cbor array") + } + + if extra > 0 { + (*t) = make([]IntAlias, extra) + } + + for i := 0; i < int(extra); i++ { + { + var maj byte + var extra uint64 + var err error + _ = maj + _ = extra + _ = err + { + maj, extra, err := cr.ReadHeader() + var extraI int64 + if err != nil { + return err + } + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + (*t)[i] = IntAlias(extraI) + } + + } + } + return nil +} + +func (t *MapTransparentType) MarshalCBOR(w io.Writer) error { + + cw := cbg.NewCborWriter(w) + + // (*t) (testing.MapTransparentType) (map) + { + if len((*t)) > 4096 { + return xerrors.Errorf("cannot marshal (*t) map too large") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajMap, uint64(len((*t)))); err != nil { + return err + } + + keys := make([]string, 0, len((*t))) + for k := range *t { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + v := (*t)[k] + + if len(k) > cbg.MaxLength { + return xerrors.Errorf("Value in field k was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(k))); err != nil { + return err + } + if _, err := cw.WriteString(string(k)); err != nil { + return err + } + + if len(v) > cbg.MaxLength { + return xerrors.Errorf("Value in field v was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(v))); err != nil { + return err + } + if _, err := cw.WriteString(string(v)); err != nil { + return err + } + + } + } + return nil +} + +func (t *MapTransparentType) UnmarshalCBOR(r io.Reader) (err error) { + *t = MapTransparentType{} + + cr := cbg.NewCborReader(r) + + var maj byte + var extra uint64 + _ = maj + _ = extra + + // (*t) (testing.MapTransparentType) (map) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + if maj != cbg.MajMap { + return fmt.Errorf("expected a map (major type 5)") + } + if extra > 4096 { + return fmt.Errorf("(*t): map too large") + } + + (*t) = make(map[string]string, extra) + + for i, l := 0, int(extra); i < l; i++ { + + var k string + + { + sval, err := cbg.ReadString(cr) + if err != nil { + return err + } + + k = string(sval) + } + + var v string + + { + sval, err := cbg.ReadString(cr) + if err != nil { + return err + } + + v = string(sval) + } + + (*t)[k] = v + + } + return nil +} diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index af7402b..2769e5d 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -494,4 +494,74 @@ func TestTransparentIntArray(t *testing.T) { t.Fatal("mismatch") } }) + + // IntArrayNewType / IntArrayAliasNewType + t.Run("roundtrip IntArrayNewType", func(t *testing.T) { + zero := &IntArrayNewType{} + recepticle := &IntArrayNewType{} + testValueRoundtrip(t, zero, recepticle) + }) + t.Run("roundtrip IntArrayAliasNewType", func(t *testing.T) { + zero := &IntArrayAliasNewType{} + recepticle := &IntArrayAliasNewType{} + testValueRoundtrip(t, zero, recepticle) + }) + t.Run("roundtrip non-zero IntArrayNewType", func(t *testing.T) { + val := &IntArrayNewType{1, 2, 3} + recepticle := &IntArrayNewType{} + testValueRoundtrip(t, val, recepticle) + }) + t.Run("roundtrip non-zero IntArrayAliasNewType", func(t *testing.T) { + val := &IntArrayAliasNewType{1, 2, 3} + recepticle := &IntArrayAliasNewType{} + testValueRoundtrip(t, val, recepticle) + }) + // NewTypes into/from TupleIntArray + t.Run("roundtrip IntArrayNewType to TupleIntArray", func(t *testing.T) { + val := IntArrayNewType{1, 2, 3} + recepticle := &TupleIntArray{} + testValueRoundtrip(t, &val, recepticle) + if val[0] != recepticle.Int1 { + t.Fatal("mismatch") + } + }) + t.Run("roundtrip IntArrayAliasNewType to TupleIntArray", func(t *testing.T) { + val := IntArrayAliasNewType{1, 2, 3} + recepticle := &TupleIntArray{} + testValueRoundtrip(t, &val, recepticle) + if int64(val[0]) != recepticle.Int1 { + t.Fatal("mismatch") + } + }) + t.Run("roundtrip TupleIntArray to IntArrayNewType", func(t *testing.T) { + val := TupleIntArray{2, 4, 5} + recepticle := IntArrayNewType{} + testValueRoundtrip(t, &val, &recepticle) + if val.Int1 != recepticle[0] { + t.Fatal("mismatch") + } + }) + t.Run("roundtrip TupleIntArray to IntArrayAliasNewType", func(t *testing.T) { + val := TupleIntArray{2, 4, 5} + recepticle := IntArrayAliasNewType{} + testValueRoundtrip(t, &val, &recepticle) + if val.Int1 != int64(recepticle[0]) { + t.Fatal("mismatch") + } + }) +} + +func TestMapTransparentType(t *testing.T) { + t.Run("roundtrip", func(t *testing.T) { + zero := MapTransparentType{} + recepticle := &MapTransparentType{} + testValueRoundtrip(t, &zero, recepticle) + }) + + // non-zero values + t.Run("roundtrip non-zero", func(t *testing.T) { + val := MapTransparentType(map[string]string{"foo": "bar"}) + recepticle := &MapTransparentType{} + testValueRoundtrip(t, &val, recepticle) + }) } diff --git a/testing/types.go b/testing/types.go index b9299f5..b962838 100644 --- a/testing/types.go +++ b/testing/types.go @@ -155,8 +155,14 @@ type IntAliasArray struct { Ints []IntAlias `cborgen:"transparent"` } +type IntArrayNewType []int64 + +type IntArrayAliasNewType []IntAlias + type TupleIntArray struct { Int1 int64 Int2 int64 Int3 int64 } + +type MapTransparentType map[string]string From 0efd71a8493b4f771c9de9eaf01f39502080f72e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Magiera?= Date: Wed, 24 Jan 2024 14:49:06 +0100 Subject: [PATCH 3/4] Cleanup transparent templates, golden value tests --- gen.go | 49 +++++++++++++++++++++++------------- testing/cbor_gen.go | 15 ----------- testing/roundtrip_test.go | 53 ++++++++++++++++++++++++++++----------- 3 files changed, 70 insertions(+), 47 deletions(-) diff --git a/gen.go b/gen.go index 926af67..9606091 100644 --- a/gen.go +++ b/gen.go @@ -206,7 +206,7 @@ func ParseTypeInfo(itype interface{}) (*GenTypeInfo, error) { for i := 0; i < t.NumField(); i++ { if out.Transparent { - return nil, fmt.Errorf("transparent structs must exactly one field") + return nil, fmt.Errorf("transparent structs must have exactly one field") } f := t.Field(i) @@ -713,22 +713,29 @@ func emitCborMarshalArrayField(w io.Writer, f Field) error { return nil } -func emitCborMarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { - // 9 byte buffer to accomodate for the maximum header length (cbor varints are maximum 9 bytes_ - err := doTemplate(w, gti, `{{if not .Transparent}}var lengthBuf{{ .Name }} = {{ .TupleHeaderAsByteString }}{{end}} +func emitCborMarshalStructTuple(w io.Writer, gti *GenTypeInfo) (err error) { + // 9 byte buffer to accommodate for the maximum header length (cbor varints are maximum 9 bytes_ + if gti.Transparent { + err = doTemplate(w, gti, ` func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { - {{if not .Transparent}}if t == nil { + cw := cbg.NewCborWriter(w) +`) + } else { + err = doTemplate(w, gti, `var lengthBuf{{ .Name }} = {{ .TupleHeaderAsByteString }} +func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { + if t == nil { _, err := w.Write(cbg.CborNull) return err - }{{end}} + } cw := cbg.NewCborWriter(w) - {{if not .Transparent}} + if _, err := cw.Write(lengthBuf{{ .Name }}); err != nil { return err } - {{end}} + `) + } if err != nil { return err } @@ -1462,13 +1469,25 @@ func emitCborUnmarshalArrayField(w io.Writer, f Field) error { return nil } -func emitCborUnmarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { - err := doTemplate(w, gti, ` +func emitCborUnmarshalStructTuple(w io.Writer, gti *GenTypeInfo) (err error) { + if gti.Transparent { + err = doTemplate(w, gti, ` +func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { + *t = {{.Name}}{} + + cr := cbg.NewCborReader(r) + var maj byte + var extra uint64 + _ = maj + _ = extra +`) + } else { + err = doTemplate(w, gti, ` func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { *t = {{.Name}}{} cr := cbg.NewCborReader(r) - {{ if not .Transparent }} + maj, extra, err := {{ ReadHeader "cr" }} if err != nil { return err @@ -1486,13 +1505,9 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { if extra != {{ len .Fields }} { return fmt.Errorf("cbor input had wrong number of fields") } - {{ else }} - var maj byte - var extra uint64 - _ = maj - _ = extra - {{ end }} + `) + } if err != nil { return err } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 62d8483..7a254d6 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -1299,7 +1299,6 @@ func (t *BigField) UnmarshalCBOR(r io.Reader) (err error) { } func (t *IntArray) MarshalCBOR(w io.Writer) error { - cw := cbg.NewCborWriter(w) // t.Ints ([]int64) (slice) @@ -1329,12 +1328,10 @@ func (t *IntArray) UnmarshalCBOR(r io.Reader) (err error) { *t = IntArray{} cr := cbg.NewCborReader(r) - var maj byte var extra uint64 _ = maj _ = extra - // t.Ints ([]int64) (slice) maj, extra, err = cr.ReadHeader() @@ -1393,7 +1390,6 @@ func (t *IntArray) UnmarshalCBOR(r io.Reader) (err error) { } func (t *IntAliasArray) MarshalCBOR(w io.Writer) error { - cw := cbg.NewCborWriter(w) // t.Ints ([]testing.IntAlias) (slice) @@ -1423,12 +1419,10 @@ func (t *IntAliasArray) UnmarshalCBOR(r io.Reader) (err error) { *t = IntAliasArray{} cr := cbg.NewCborReader(r) - var maj byte var extra uint64 _ = maj _ = extra - // t.Ints ([]testing.IntAlias) (slice) maj, extra, err = cr.ReadHeader() @@ -1637,7 +1631,6 @@ func (t *TupleIntArray) UnmarshalCBOR(r io.Reader) (err error) { } func (t *IntArrayNewType) MarshalCBOR(w io.Writer) error { - cw := cbg.NewCborWriter(w) // (*t) (testing.IntArrayNewType) (slice) @@ -1667,12 +1660,10 @@ func (t *IntArrayNewType) UnmarshalCBOR(r io.Reader) (err error) { *t = IntArrayNewType{} cr := cbg.NewCborReader(r) - var maj byte var extra uint64 _ = maj _ = extra - // (*t) (testing.IntArrayNewType) (slice) maj, extra, err = cr.ReadHeader() @@ -1731,7 +1722,6 @@ func (t *IntArrayNewType) UnmarshalCBOR(r io.Reader) (err error) { } func (t *IntArrayAliasNewType) MarshalCBOR(w io.Writer) error { - cw := cbg.NewCborWriter(w) // (*t) (testing.IntArrayAliasNewType) (slice) @@ -1761,12 +1751,10 @@ func (t *IntArrayAliasNewType) UnmarshalCBOR(r io.Reader) (err error) { *t = IntArrayAliasNewType{} cr := cbg.NewCborReader(r) - var maj byte var extra uint64 _ = maj _ = extra - // (*t) (testing.IntArrayAliasNewType) (slice) maj, extra, err = cr.ReadHeader() @@ -1825,7 +1813,6 @@ func (t *IntArrayAliasNewType) UnmarshalCBOR(r io.Reader) (err error) { } func (t *MapTransparentType) MarshalCBOR(w io.Writer) error { - cw := cbg.NewCborWriter(w) // (*t) (testing.MapTransparentType) (map) @@ -1877,12 +1864,10 @@ func (t *MapTransparentType) UnmarshalCBOR(r io.Reader) (err error) { *t = MapTransparentType{} cr := cbg.NewCborReader(r) - var maj byte var extra uint64 _ = maj _ = extra - // (*t) (testing.MapTransparentType) (map) maj, extra, err = cr.ReadHeader() diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index 2769e5d..bd607aa 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -53,9 +53,26 @@ func TestNilPreserveWorks(t *testing.T) { testTypeRoundtrips(t, reflect.TypeOf(TestSliceNilPreserve{})) } -func testValueRoundtrip(t *testing.T, obj cbg.CBORMarshaler, nobj cbg.CBORUnmarshaler) { +type RoundTripOptions struct { + Golden []byte +} + +type RoundTripOption func(*RoundTripOptions) + +func WithGolden(golden []byte) RoundTripOption { + return func(opts *RoundTripOptions) { + opts.Golden = golden + } +} + +func testValueRoundtrip(t *testing.T, obj cbg.CBORMarshaler, nobj cbg.CBORUnmarshaler, options ...RoundTripOption) { t.Helper() + opts := &RoundTripOptions{} + for _, option := range options { + option(opts) + } + buf := new(bytes.Buffer) if err := obj.MarshalCBOR(buf); err != nil { t.Fatal("i guess its fine to fail marshaling") @@ -63,6 +80,12 @@ func testValueRoundtrip(t *testing.T, obj cbg.CBORMarshaler, nobj cbg.CBORUnmars enc := buf.Bytes() + if opts.Golden != nil { + if !bytes.Equal(opts.Golden, enc) { + t.Fatalf("encoding mismatch: %x != %x", opts.Golden, enc) + } + } + if err := nobj.UnmarshalCBOR(bytes.NewReader(enc)); err != nil { t.Logf("got bad bytes: %x", enc) t.Fatal("failed to round trip object: ", err) @@ -456,20 +479,20 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip", func(t *testing.T) { zero := &IntArray{} recepticle := &IntArray{} - testValueRoundtrip(t, zero, recepticle) + testValueRoundtrip(t, zero, recepticle, WithGolden([]byte{0x80})) }) t.Run("roundtrip intalias", func(t *testing.T) { zero := &IntAliasArray{} recepticle := &IntAliasArray{} - testValueRoundtrip(t, zero, recepticle) + testValueRoundtrip(t, zero, recepticle, WithGolden([]byte{0x80})) }) // non-zero values t.Run("roundtrip non-zero", func(t *testing.T) { val := &IntArray{Ints: []int64{1, 2, 3}} recepticle := &IntArray{} - testValueRoundtrip(t, val, recepticle) + testValueRoundtrip(t, val, recepticle, WithGolden([]byte{0x83, 0x01, 0x02, 0x03})) }) t.Run("roundtrip non-zero intalias", func(t *testing.T) { val := &IntAliasArray{Ints: []IntAlias{1, 2, 3}} @@ -481,7 +504,7 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip tuple struct to transparent", func(t *testing.T) { val := &TupleIntArray{2, 4, 5} recepticle := &IntArray{} - testValueRoundtrip(t, val, recepticle) + testValueRoundtrip(t, val, recepticle, WithGolden([]byte{0x83, 0x02, 0x04, 0x05})) if val.Int1 != recepticle.Ints[0] { t.Fatal("mismatch") } @@ -489,7 +512,7 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip transparent to tuple struct", func(t *testing.T) { val := &IntArray{Ints: []int64{2, 4, 5}} recepticle := &TupleIntArray{} - testValueRoundtrip(t, val, recepticle) + testValueRoundtrip(t, val, recepticle, WithGolden([]byte{0x83, 0x02, 0x04, 0x05})) if val.Ints[0] != recepticle.Int1 { t.Fatal("mismatch") } @@ -499,7 +522,7 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip IntArrayNewType", func(t *testing.T) { zero := &IntArrayNewType{} recepticle := &IntArrayNewType{} - testValueRoundtrip(t, zero, recepticle) + testValueRoundtrip(t, zero, recepticle, WithGolden([]byte{0x80})) }) t.Run("roundtrip IntArrayAliasNewType", func(t *testing.T) { zero := &IntArrayAliasNewType{} @@ -509,18 +532,18 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip non-zero IntArrayNewType", func(t *testing.T) { val := &IntArrayNewType{1, 2, 3} recepticle := &IntArrayNewType{} - testValueRoundtrip(t, val, recepticle) + testValueRoundtrip(t, val, recepticle, WithGolden([]byte{0x83, 0x01, 0x02, 0x03})) }) t.Run("roundtrip non-zero IntArrayAliasNewType", func(t *testing.T) { val := &IntArrayAliasNewType{1, 2, 3} recepticle := &IntArrayAliasNewType{} - testValueRoundtrip(t, val, recepticle) + testValueRoundtrip(t, val, recepticle, WithGolden([]byte{0x83, 0x01, 0x02, 0x03})) }) // NewTypes into/from TupleIntArray t.Run("roundtrip IntArrayNewType to TupleIntArray", func(t *testing.T) { val := IntArrayNewType{1, 2, 3} recepticle := &TupleIntArray{} - testValueRoundtrip(t, &val, recepticle) + testValueRoundtrip(t, &val, recepticle, WithGolden([]byte{0x83, 0x01, 0x02, 0x03})) if val[0] != recepticle.Int1 { t.Fatal("mismatch") } @@ -528,7 +551,7 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip IntArrayAliasNewType to TupleIntArray", func(t *testing.T) { val := IntArrayAliasNewType{1, 2, 3} recepticle := &TupleIntArray{} - testValueRoundtrip(t, &val, recepticle) + testValueRoundtrip(t, &val, recepticle, WithGolden([]byte{0x83, 0x01, 0x02, 0x03})) if int64(val[0]) != recepticle.Int1 { t.Fatal("mismatch") } @@ -536,7 +559,7 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip TupleIntArray to IntArrayNewType", func(t *testing.T) { val := TupleIntArray{2, 4, 5} recepticle := IntArrayNewType{} - testValueRoundtrip(t, &val, &recepticle) + testValueRoundtrip(t, &val, &recepticle, WithGolden([]byte{0x83, 0x02, 0x04, 0x05})) if val.Int1 != recepticle[0] { t.Fatal("mismatch") } @@ -544,7 +567,7 @@ func TestTransparentIntArray(t *testing.T) { t.Run("roundtrip TupleIntArray to IntArrayAliasNewType", func(t *testing.T) { val := TupleIntArray{2, 4, 5} recepticle := IntArrayAliasNewType{} - testValueRoundtrip(t, &val, &recepticle) + testValueRoundtrip(t, &val, &recepticle, WithGolden([]byte{0x83, 0x02, 0x04, 0x05})) if val.Int1 != int64(recepticle[0]) { t.Fatal("mismatch") } @@ -555,13 +578,13 @@ func TestMapTransparentType(t *testing.T) { t.Run("roundtrip", func(t *testing.T) { zero := MapTransparentType{} recepticle := &MapTransparentType{} - testValueRoundtrip(t, &zero, recepticle) + testValueRoundtrip(t, &zero, recepticle, WithGolden([]byte{0xa0})) }) // non-zero values t.Run("roundtrip non-zero", func(t *testing.T) { val := MapTransparentType(map[string]string{"foo": "bar"}) recepticle := &MapTransparentType{} - testValueRoundtrip(t, &val, recepticle) + testValueRoundtrip(t, &val, recepticle, WithGolden([]byte{0xa1, 0x63, 0x66, 0x6f, 0x6f, 0x63, 0x62, 0x61, 0x72})) }) } From 6509eab6cdca33c22d15bb9abcd5aab678e351b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Magiera?= Date: Wed, 24 Jan 2024 15:12:17 +0100 Subject: [PATCH 4/4] transparent: Name the magic '.' field name --- gen.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/gen.go b/gen.go index 9606091..d97f091 100644 --- a/gen.go +++ b/gen.go @@ -87,6 +87,10 @@ var _ = sort.Sort `) } +// FieldNameSelf is the name of the field that is the marshal target itself. +// This is used in non-struct types which are handled like transparent structs. +const FieldNameSelf = "." + type Field struct { Name string MapKey string @@ -189,7 +193,7 @@ func ParseTypeInfo(itype interface{}) (*GenTypeInfo, error) { Transparent: true, Fields: []Field{ { - Name: ".", + Name: FieldNameSelf, MapKey: "", Pointer: t.Kind() == reflect.Ptr, Type: t, @@ -741,7 +745,7 @@ func (t *{{ .Name }}) MarshalCBOR(w io.Writer) error { } for _, f := range gti.Fields { - if f.Name == "." { + if f.Name == FieldNameSelf { f.Name = "(*t)" } else { f.Name = "t." + f.Name @@ -1513,7 +1517,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { } for _, f := range gti.Fields { - if f.Name == "." { + if f.Name == FieldNameSelf { f.Name = "(*t)" // self } else { f.Name = "t." + f.Name