Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pkg/scale): support for custom VaryingDataType types #2612

Merged
merged 7 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 192 additions & 4 deletions pkg/scale/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ SCALE uses a compact encoding for variable width unsigned integers.
### Basic Example

Basic example which encodes and decodes a `uint`.
```
```go
import (
"fmt"
"github.com/ChainSafe/gossamer/pkg/scale"
Expand Down Expand Up @@ -111,7 +111,7 @@ func ExampleBasic() {

Use the `scale` struct tag for struct fields to conform to specific encoding sequence of struct field values. A struct tag of `"-"` will be omitted from encoding and decoding.

```
```go
import (
"fmt"
"github.com/ChainSafe/gossamer/pkg/scale"
Expand Down Expand Up @@ -159,7 +159,7 @@ result := scale.NewResult(int32(0), int32(0)
result.Set(scale.Ok, 10)
```

```
```go
import (
"fmt"
"github.com/ChainSafe/gossamer/pkg/scale"
Expand Down Expand Up @@ -213,7 +213,7 @@ func ExampleResult() {
A `VaryingDataType` is analogous to a Rust enum. A `VaryingDataType` needs to be constructed using the `NewVaryingDataType` constructor. `VaryingDataTypeValue` is an
interface with one `Index() uint` method that needs to be implemented. The returned `uint` index should be unique per type and needs to be the same index as defined in the Rust enum to ensure interopability. To set the value of the `VaryingDataType`, the `VaryingDataType.Set()` function should be called with an associated `VaryingDataTypeValue`.

```
```go
import (
"fmt"
"github.com/ChainSafe/gossamer/pkg/scale"
Expand Down Expand Up @@ -323,4 +323,192 @@ func ExampleVaryingDataTypeSlice() {
panic(fmt.Errorf("uh oh: %+v %+v", vdts, vdts1))
}
}
```

#### Nested VaryingDataType

See `varying_data_type_nested_example.go` for a working example of a custom `VaryingDataType` with another custom `VaryingDataType` as a value of the parent `VaryingDataType`. In the case of nested `VaryingDataTypes`, a custom type needs to be created for the child `VaryingDataType` because it needs to fulfill the `VaryingDataTypeValue` interface.

```go
import (
"fmt"
"reflect"

"github.com/ChainSafe/gossamer/pkg/scale"
)

// ParentVDT is a VaryingDataType that consists of multiple nested VaryingDataType
// instances (aka. a rust enum containing multiple enum options)
type ParentVDT scale.VaryingDataType

// Set will set a VaryingDataTypeValue using the underlying VaryingDataType
func (pvdt *ParentVDT) Set(val scale.VaryingDataTypeValue) (err error) {
// cast to VaryingDataType to use VaryingDataType.Set method
vdt := scale.VaryingDataType(*pvdt)
err = vdt.Set(val)
if err != nil {
return
}
// store original ParentVDT with VaryingDataType that has been set
*pvdt = ParentVDT(vdt)
return
}

// Value will return value from underying VaryingDataType
func (pvdt *ParentVDT) Value() (val scale.VaryingDataTypeValue) {
vdt := scale.VaryingDataType(*pvdt)
return vdt.Value()
}

// NewParentVDT is constructor for ParentVDT
func NewParentVDT() ParentVDT {
// use standard VaryingDataType constructor to construct a VaryingDataType
vdt, err := scale.NewVaryingDataType(NewChildVDT(), NewOtherChildVDT())
if err != nil {
panic(err)
}
// cast to ParentVDT
return ParentVDT(vdt)
}

// ChildVDT type is used as a VaryingDataTypeValue for ParentVDT
type ChildVDT scale.VaryingDataType

// Index fulfills the VaryingDataTypeValue interface. T
func (cvdt ChildVDT) Index() uint {
return 1
}

// Set will set a VaryingDataTypeValue using the underlying VaryingDataType
func (cvdt *ChildVDT) Set(val scale.VaryingDataTypeValue) (err error) {
// cast to VaryingDataType to use VaryingDataType.Set method
vdt := scale.VaryingDataType(*cvdt)
err = vdt.Set(val)
if err != nil {
return
}
// store original ParentVDT with VaryingDataType that has been set
*cvdt = ChildVDT(vdt)
return
}

// Value will return value from underying VaryingDataType
func (cvdt *ChildVDT) Value() (val scale.VaryingDataTypeValue) {
vdt := scale.VaryingDataType(*cvdt)
return vdt.Value()
}

// NewChildVDT is constructor for ChildVDT
func NewChildVDT() ChildVDT {
// use standard VaryingDataType constructor to construct a VaryingDataType
// constarined to types ChildInt16, ChildStruct, and ChildString
vdt, err := scale.NewVaryingDataType(ChildInt16(0), ChildStruct{}, ChildString(""))
if err != nil {
panic(err)
}
// cast to ParentVDT
return ChildVDT(vdt)
}

// OtherChildVDT type is used as a VaryingDataTypeValue for ParentVDT
type OtherChildVDT scale.VaryingDataType

// Index fulfills the VaryingDataTypeValue interface.
func (ocvdt OtherChildVDT) Index() uint {
return 2
}

// Set will set a VaryingDataTypeValue using the underlying VaryingDataType
func (cvdt *OtherChildVDT) Set(val scale.VaryingDataTypeValue) (err error) {
// cast to VaryingDataType to use VaryingDataType.Set method
vdt := scale.VaryingDataType(*cvdt)
err = vdt.Set(val)
if err != nil {
return
}
// store original ParentVDT with VaryingDataType that has been set
*cvdt = OtherChildVDT(vdt)
return
}

// NewOtherChildVDT is constructor for OtherChildVDT
func NewOtherChildVDT() OtherChildVDT {
// use standard VaryingDataType constructor to construct a VaryingDataType
// constarined to types ChildInt16 and ChildStruct
vdt, err := scale.NewVaryingDataType(ChildInt16(0), ChildStruct{}, ChildString(""))
if err != nil {
panic(err)
}
// cast to ParentVDT
return OtherChildVDT(vdt)
}

// ChildInt16 is used as a VaryingDataTypeValue for ChildVDT and OtherChildVDT
type ChildInt16 int16

// Index fulfills the VaryingDataTypeValue interface. The ChildVDT type is used as a
// VaryingDataTypeValue for ParentVDT
func (ci ChildInt16) Index() uint {
return 1
}

// ChildStruct is used as a VaryingDataTypeValue for ChildVDT and OtherChildVDT
type ChildStruct struct {
A string
B bool
}

// Index fulfills the VaryingDataTypeValue interface
func (cs ChildStruct) Index() uint {
return 2
}

// ChildString is used as a VaryingDataTypeValue for ChildVDT and OtherChildVDT
type ChildString string

// Index fulfills the VaryingDataTypeValue interface
func (cs ChildString) Index() uint {
return 3
}

func ExampleNestedVaryingDataType() {
parent := NewParentVDT()

// populate parent with ChildVDT
child := NewChildVDT()
child.Set(ChildInt16(888))
err := parent.Set(child)
if err != nil {
panic(err)
}

// validate ParentVDT.Value()
fmt.Printf("parent.Value(): %+v\n", parent.Value())
// should cast to ChildVDT, since that was set earlier
valChildVDT := parent.Value().(ChildVDT)
// validate ChildVDT.Value() as ChildInt16(888)
fmt.Printf("child.Value(): %+v\n", valChildVDT.Value())

// marshal into scale encoded bytes
bytes, err := scale.Marshal(parent)
if err != nil {
panic(err)
}
fmt.Printf("bytes: % x\n", bytes)

// unmarshal into another ParentVDT
dstParent := NewParentVDT()
err = scale.Unmarshal(bytes, &dstParent)
if err != nil {
panic(err)
}
// assert both ParentVDT instances are the same
fmt.Println(reflect.DeepEqual(parent, dstParent))

// Output:
// parent.Value(): {value:888 cache:map[1:0 2:{A: B:false} 3:]}
// child.Value(): 888
// bytes: 01 01 78 03
// true
}
```
27 changes: 23 additions & 4 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,12 @@ func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
case reflect.Ptr:
err = ds.decodePointer(dstv)
case reflect.Struct:
err = ds.decodeStruct(dstv)
ok := reflect.ValueOf(in).CanConvert(reflect.TypeOf(VaryingDataType{}))
if ok {
err = ds.decodeCustomVaryingDataType(dstv)
} else {
err = ds.decodeStruct(dstv)
}
case reflect.Array:
err = ds.decodeArray(dstv)
case reflect.Slice:
Expand Down Expand Up @@ -344,6 +349,19 @@ func (ds *decodeState) decodeVaryingDataTypeSlice(dstv reflect.Value) (err error
return
}

func (ds *decodeState) decodeCustomVaryingDataType(dstv reflect.Value) (err error) {
initialType := dstv.Type()
converted := dstv.Convert(reflect.TypeOf(VaryingDataType{}))
tempVal := reflect.New(converted.Type())
tempVal.Elem().Set(converted)
err = ds.decodeVaryingDataType(tempVal.Elem())
if err != nil {
return
}
dstv.Set(tempVal.Elem().Convert(initialType))
return
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need tempVal? Can we not directly used converted?

Also, why do you need to set dstv in the end? It does not look like dstv was modified in between (or was it?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dstv is not modified. dstv.Convert gives you a new instance that is converted but not addressable. So a tempVal needs to be created which is then decoded to. Then I set the original dstv with the decoded data converted back to the initial custom type.


func (ds *decodeState) decodeVaryingDataType(dstv reflect.Value) (err error) {
var b byte
b, err = ds.ReadByte()
Expand All @@ -358,12 +376,13 @@ func (ds *decodeState) decodeVaryingDataType(dstv reflect.Value) (err error) {
return
}

tempVal := reflect.New(reflect.TypeOf(val)).Elem()
err = ds.unmarshal(tempVal)
tempVal := reflect.New(reflect.TypeOf(val))
tempVal.Elem().Set(reflect.ValueOf(val))
err = ds.unmarshal(tempVal.Elem())
if err != nil {
return
}
err = vdt.Set(tempVal.Interface().(VaryingDataTypeValue))
err = vdt.Set(tempVal.Elem().Interface().(VaryingDataTypeValue))
if err != nil {
return
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/scale/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ func (es *encodeState) marshal(in interface{}) (err error) {
err = es.marshal(elem.Interface())
}
case reflect.Struct:
err = es.encodeStruct(in)
ok := reflect.ValueOf(in).CanConvert(reflect.TypeOf(VaryingDataType{}))
if ok {
err = es.encodeCustomVaryingDataType(in)
} else {
err = es.encodeStruct(in)
}
case reflect.Array:
err = es.encodeArray(in)
case reflect.Slice:
Expand Down Expand Up @@ -148,6 +153,11 @@ func (es *encodeState) encodeResult(res Result) (err error) {
return
}

func (es *encodeState) encodeCustomVaryingDataType(in interface{}) (err error) {
vdt := reflect.ValueOf(in).Convert(reflect.TypeOf(VaryingDataType{})).Interface().(VaryingDataType)
return es.encodeVaryingDataType(vdt)
}

func (es *encodeState) encodeVaryingDataType(vdt VaryingDataType) (err error) {
err = es.WriteByte(byte(vdt.value.Index()))
if err != nil {
Expand Down
Loading