diff --git a/dot/import.go b/dot/import.go index 25d0337b71..0a78503f82 100644 --- a/dot/import.go +++ b/dot/import.go @@ -21,7 +21,6 @@ import ( "encoding/json" "errors" "io/ioutil" - "math/big" "path/filepath" "github.com/ChainSafe/gossamer/dot/state" @@ -97,8 +96,7 @@ func newHeaderFromFile(filename string) (*types.Header, error) { return nil, errors.New("invalid number field in header JSON") } - numBytes := common.MustHexToBytes(hexNum) - num := big.NewInt(0).SetBytes(numBytes) + num := common.MustHexToBigInt(hexNum) parentHashStr, ok := jsonHeader["parentHash"].(string) if !ok { diff --git a/lib/common/common.go b/lib/common/common.go index 2b6d5a9f80..fea9e961f6 100644 --- a/lib/common/common.go +++ b/lib/common/common.go @@ -21,6 +21,7 @@ import ( "encoding/hex" "errors" "io" + "math/big" "strconv" "strings" ) @@ -103,6 +104,32 @@ func MustHexToBytes(in string) []byte { return out } +// MustHexToBigInt turns a 0x prefixed hex string into a big.Int +// it panic if it cannot decode the string +func MustHexToBigInt(in string) *big.Int { + if len(in) < 2 { + panic("invalid string") + } + + if strings.Compare(in[:2], "0x") != 0 { + panic(ErrNoPrefix) + } + + in = in[2:] + + // Ensure we have an even length + if len(in)%2 != 0 { + in = "0" + in + } + + out, err := hex.DecodeString(in) + if err != nil { + panic(err) + } + + return big.NewInt(0).SetBytes(out) +} + // BytesToHex turns a byte slice into a 0x prefixed hex string func BytesToHex(in []byte) string { s := hex.EncodeToString(in) diff --git a/lib/common/common_test.go b/lib/common/common_test.go index 0b574f09a5..b88c5aabcc 100644 --- a/lib/common/common_test.go +++ b/lib/common/common_test.go @@ -18,8 +18,12 @@ package common import ( "bytes" + "math/big" "reflect" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStringToInts(t *testing.T) { @@ -182,3 +186,32 @@ func TestSwapNibbles(t *testing.T) { } } } + +func TestMustHexToBigInt(t *testing.T) { + tests := []struct { + in string + out *big.Int + }{ + {"0x0", big.NewInt(0).SetBytes([]byte{0})}, + {"0x00", big.NewInt(0).SetBytes([]byte{0})}, + {"0x1", big.NewInt(1)}, + {"0x01", big.NewInt(1)}, + {"0xf", big.NewInt(15)}, + {"0x0f", big.NewInt(15)}, + {"0x10", big.NewInt(16)}, + {"0xff", big.NewInt(255)}, + {"0x50429", big.NewInt(328745)}, + {"0x050429", big.NewInt(328745)}, + } + + for _, test := range tests { + res := MustHexToBigInt(test.in) + require.Equal(t, test.out, res) + } +} + +func TestMustHexToBigIntPanic(t *testing.T) { + assert.Panics(t, func() { MustHexToBigInt("1") }, "should panic for string len < 2") + assert.Panics(t, func() { MustHexToBigInt("12") }, "should panic for string not starting with 0x") + assert.Panics(t, func() { MustHexToBigInt("0xzz") }, "should panic for string not containing hex characters") +}