Skip to content

Commit

Permalink
feat(scale): add range checks to decodeUint function (#2683)
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardmack authored Sep 8, 2022
1 parent 62d750d commit ac700f8
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 40 deletions.
6 changes: 3 additions & 3 deletions internal/trie/node/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func Test_decodeBranch(t *testing.T) {
variant: branchVariant.bits,
partialKeyLength: 1,
errWrapped: ErrDecodeChildHash,
errMessage: "cannot decode child hash: at index 10: EOF",
errMessage: "cannot decode child hash: at index 10: reading byte: EOF",
},
"success for branch variant": {
reader: bytes.NewBuffer(
Expand Down Expand Up @@ -203,7 +203,7 @@ func Test_decodeBranch(t *testing.T) {
variant: branchWithValueVariant.bits,
partialKeyLength: 1,
errWrapped: ErrDecodeValue,
errMessage: "cannot decode value: EOF",
errMessage: "cannot decode value: reading byte: EOF",
},
"success for branch with value": {
reader: bytes.NewBuffer(concatByteSlices([][]byte{
Expand Down Expand Up @@ -333,7 +333,7 @@ func Test_decodeLeaf(t *testing.T) {
variant: leafVariant.bits,
partialKeyLength: 1,
errWrapped: ErrDecodeValue,
errMessage: "cannot decode value: could not decode invalid integer",
errMessage: "cannot decode value: unknown prefix for compact uint: 255",
},
"zero value": {
reader: bytes.NewBuffer([]byte{
Expand Down
2 changes: 1 addition & 1 deletion lib/runtime/version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func Test_DecodeVersion(t *testing.T) {
{255, 255}, // error
}),
errWrapped: ErrDecodingVersionField,
errMessage: "decoding version field impl name: could not decode invalid integer",
errMessage: "decoding version field impl name: unknown prefix for compact uint: 255",
},
// TODO add transaction version decode error once
// https://github.com/ChainSafe/gossamer/pull/2683
Expand Down
91 changes: 61 additions & 30 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func (ds *decodeState) decodeVaryingDataTypeSlice(dstv reflect.Value) (err error
if err != nil {
return
}
for i := 0; i < l; i++ {
for i := uint(0); i < l; i++ {
vdt := vdts.VaryingDataType
vdtv := reflect.New(reflect.TypeOf(vdt))
vdtv.Elem().Set(reflect.ValueOf(vdt))
Expand Down Expand Up @@ -397,7 +397,7 @@ func (ds *decodeState) decodeSlice(dstv reflect.Value) (err error) {
}
in := dstv.Interface()
temp := reflect.New(reflect.ValueOf(in).Type())
for i := 0; i < l; i++ {
for i := uint(0); i < l; i++ {
tempElemType := reflect.TypeOf(in).Elem()
tempElem := reflect.New(tempElemType).Elem()

Expand Down Expand Up @@ -478,59 +478,90 @@ func (ds *decodeState) decodeBool(dstv reflect.Value) (err error) {

// decodeUint will decode unsigned integer
func (ds *decodeState) decodeUint(dstv reflect.Value) (err error) {
b, err := ds.ReadByte()
const maxUint32 = ^uint32(0)
const maxUint64 = ^uint64(0)
prefix, err := ds.ReadByte()
if err != nil {
return
return fmt.Errorf("reading byte: %w", err)
}

in := dstv.Interface()
temp := reflect.New(reflect.TypeOf(in))
// check mode of encoding, stored at 2 least significant bits
mode := b & 3
switch {
case mode <= 2:
var val int64
val, err = ds.decodeSmallInt(b, mode)
mode := prefix % 4
var value uint64
switch mode {
case 0:
value = uint64(prefix >> 2)
case 1:
buf, err := ds.ReadByte()
if err != nil {
return
return fmt.Errorf("reading byte: %w", err)
}
temp.Elem().Set(reflect.ValueOf(val).Convert(reflect.TypeOf(in)))
dstv.Set(temp.Elem())
default:
// >4 byte mode
topSixBits := b >> 2
byteLen := uint(topSixBits) + 4

value = uint64(binary.LittleEndian.Uint16([]byte{prefix, buf}) >> 2)
if value <= 0b0011_1111 || value > 0b0111_1111_1111_1111 {
return fmt.Errorf("%w: %d (%b)", ErrU16OutOfRange, value, value)
}
case 2:
buf := make([]byte, 3)
_, err = ds.Read(buf)
if err != nil {
return fmt.Errorf("reading bytes: %w", err)
}
value = uint64(binary.LittleEndian.Uint32(append([]byte{prefix}, buf...)) >> 2)
if value <= 0b0011_1111_1111_1111 || value > uint64(maxUint32>>2) {
return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value)
}
case 3:
byteLen := (prefix >> 2) + 4
buf := make([]byte, byteLen)
_, err = ds.Read(buf)
if err != nil {
return
return fmt.Errorf("reading bytes: %w", err)
}

var o uint64
if byteLen == 4 {
o = uint64(binary.LittleEndian.Uint32(buf))
} else if byteLen > 4 && byteLen <= 8 {
switch byteLen {
case 4:
value = uint64(binary.LittleEndian.Uint32(buf))
if value <= uint64(maxUint32>>2) {
return fmt.Errorf("%w: %d (%b)", ErrU32OutOfRange, value, value)
}
case 8:
const uintSize = 32 << (^uint(0) >> 32 & 1)
if uintSize == 32 {
return ErrU64NotSupported
}
tmp := make([]byte, 8)
copy(tmp, buf)
o = binary.LittleEndian.Uint64(tmp)
} else {
err = errors.New("could not decode invalid integer")
return
value = binary.LittleEndian.Uint64(tmp)
if value <= maxUint64>>8 {
return fmt.Errorf("%w: %d (%b)", ErrU64OutOfRange, value, value)
}
default:
return fmt.Errorf("%w: %d", ErrCompactUintPrefixUnknown, prefix)

}
dstv.Set(reflect.ValueOf(o).Convert(reflect.TypeOf(in)))
}
temp.Elem().Set(reflect.ValueOf(value).Convert(reflect.TypeOf(in)))
dstv.Set(temp.Elem())
return
}

var (
ErrU16OutOfRange = errors.New("uint16 out of range")
ErrU32OutOfRange = errors.New("uint32 out of range")
ErrU64OutOfRange = errors.New("uint64 out of range")
ErrU64NotSupported = errors.New("uint64 is not supported")
ErrCompactUintPrefixUnknown = errors.New("unknown prefix for compact uint")
)

// decodeLength is helper method which calls decodeUint and casts to int
func (ds *decodeState) decodeLength() (l int, err error) {
func (ds *decodeState) decodeLength() (l uint, err error) {
dstv := reflect.New(reflect.TypeOf(l))
err = ds.decodeUint(dstv.Elem())
if err != nil {
return
}
l = dstv.Elem().Interface().(int)
l = dstv.Elem().Interface().(uint)
return
}

Expand Down
99 changes: 99 additions & 0 deletions pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
)

func Test_decodeState_decodeFixedWidthInt(t *testing.T) {
Expand Down Expand Up @@ -302,3 +303,101 @@ func Test_Decoder_Decode_MultipleCalls(t *testing.T) {
})
}
}

func Test_decodeState_decodeUint(t *testing.T) {
t.Parallel()
decodeUint32Tests := tests{
{
name: "int(1) mode 0",
in: uint32(1),
want: []byte{0x04},
},
{
name: "int(16383) mode 1",
in: int(16383),
want: []byte{0xfd, 0xff},
},
{
name: "int(1073741823) mode 2",
in: int(1073741823),
want: []byte{0xfe, 0xff, 0xff, 0xff},
},
{
name: "int(4294967295) mode 3",
in: int(4294967295),
want: []byte{0x3, 0xff, 0xff, 0xff, 0xff},
},
{
name: "myCustomInt(9223372036854775807) mode 3, 64bit",
in: myCustomInt(9223372036854775807),
want: []byte{19, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
},
{
name: "uint(overload)",
in: int(0),
want: []byte{0x07, 0x08, 0x09, 0x10, 0x0, 0x40},
wantErr: true,
},
{
name: "uint(16384) mode 2",
in: int(16384),
want: []byte{0x02, 0x00, 0x01, 0x0},
},
{
name: "uint(0) mode 1, error",
in: int(0),
want: []byte{0x01, 0x00},
wantErr: true,
},
{
name: "uint(0) mode 2, error",
in: int(0),
want: []byte{0x02, 0x00, 0x00, 0x0},
wantErr: true,
},
{
name: "uint(0) mode 3, error",
in: int(0),
want: []byte{0x03, 0x00, 0x00, 0x0},
wantErr: true,
},
{
name: "mode 3, 64bit, error",
in: int(0),
want: []byte{19, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
wantErr: true,
},
{
name: "[]int{1 << 32, 2, 3, 1 << 32}",
in: uint(4),
want: []byte{0x10, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
},
{
name: "[4]int{1 << 32, 2, 3, 1 << 32}",
in: [4]int{0, 0, 0, 0},
want: []byte{0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
wantErr: true,
},
}

for _, tt := range decodeUint32Tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface()
dstv := reflect.ValueOf(&dst)
elem := indirect(dstv)

ds := decodeState{
Reader: bytes.NewBuffer(tt.want),
}
err := ds.decodeUint(elem)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.in, dst)
})
}
}
21 changes: 15 additions & 6 deletions pkg/scale/encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ var (
in: int(1),
want: []byte{0x04},
},
{
name: "int(42)",
in: int(42),
want: []byte{0xa8},
},
{
name: "int(16383)",
in: int(16383),
Expand Down Expand Up @@ -821,9 +826,11 @@ var (
want: []byte{0x10, 0x03, 0x00, 0x00, 0x00, 0x40, 0x08, 0x0c, 0x10},
},
{
name: "[]int{1 << 32, 2, 3, 1 << 32}",
in: []int{1 << 32, 2, 3, 1 << 32},
want: []byte{0x10, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
name: "[]int64{1 << 32, 2, 3, 1 << 32}",
in: []int64{1 << 32, 2, 3, 1 << 32},
want: []byte{0x10, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0x00},
},
{
name: "[]bool{true, false, true}",
Expand Down Expand Up @@ -864,9 +871,11 @@ var (
want: []byte{0x03, 0x00, 0x00, 0x00, 0x40, 0x08, 0x0c, 0x10},
},
{
name: "[4]int{1 << 32, 2, 3, 1 << 32}",
in: [4]int{1 << 32, 2, 3, 1 << 32},
want: []byte{0x07, 0x00, 0x00, 0x00, 0x00, 0x01, 0x08, 0x0c, 0x07, 0x00, 0x00, 0x00, 0x00, 0x01},
name: "[4]int64{1 << 32, 2, 3, 1 << 32}",
in: [4]int64{1 << 32, 2, 3, 1 << 32},
want: []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0x00},
},
{
name: "[3]bool{true, false, true}",
Expand Down

0 comments on commit ac700f8

Please sign in to comment.