Skip to content

Commit

Permalink
fix: recursive arg validation (#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
loicbourgois authored Feb 19, 2020
1 parent 1ed8606 commit fa49149
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 14 deletions.
80 changes: 74 additions & 6 deletions internal/core/reflect.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package core

import (
"fmt"
"reflect"
"sort"
"strings"

"github.com/scaleway/scaleway-sdk-go/strcase"
Expand All @@ -20,14 +22,80 @@ func newObjectWithForcedJSONTags(t reflect.Type) interface{} {
return reflect.New(reflect.StructOf(structFieldsCopy)).Interface()
}

// getValueForFieldByName search for a field in a cmdArgs and returns its value if this field exists.
// getValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist.
// The search is based on the name of the field.
func getValueForFieldByName(cmdArgs interface{}, fieldName string) (value reflect.Value, isValid bool) {
field := reflect.ValueOf(cmdArgs).Elem().FieldByName(fieldName)
if !field.IsValid() {
return field, false
func getValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) {
if len(parts) == 0 {
return []reflect.Value{value}, nil
}
return field, true

switch value.Kind() {
case reflect.Ptr:
return getValuesForFieldByName(value.Elem(), parts)

case reflect.Slice:
values := []reflect.Value(nil)
for i := 0; i < value.Len(); i++ {
newValues, err := getValuesForFieldByName(value.Index(i), parts[1:])
if err != nil {
return nil, err
}
values = append(values, newValues...)
}
return values, nil

case reflect.Map:
if value.IsNil() {
return nil, nil
}

values := []reflect.Value(nil)

mapKeys := value.MapKeys()
sort.Slice(mapKeys, func(i, j int) bool {
return mapKeys[i].String() < mapKeys[j].String()
})

for _, mapKey := range mapKeys {
mapValue := value.MapIndex(mapKey)
newValues, err := getValuesForFieldByName(mapValue, parts[1:])
if err != nil {
return nil, err
}
values = append(values, newValues...)
}
return values, nil

case reflect.Struct:
anonymousFieldIndexes := []int(nil)
fieldIndexByName := map[string]int{}

for i := 0; i < value.NumField(); i++ {
field := value.Type().Field(i)
if field.Anonymous {
anonymousFieldIndexes = append(anonymousFieldIndexes, i)
} else {
fieldIndexByName[field.Name] = i
}
}

fieldName := strcase.ToPublicGoName(parts[0])
if fieldIndex, exist := fieldIndexByName[fieldName]; exist {
return getValuesForFieldByName(value.Field(fieldIndex), parts[1:])
}

// If it does not exist we try to find it in nested anonymous field
for fieldIndex := len(anonymousFieldIndexes) - 1; fieldIndex >= 0; fieldIndex-- {
newValues, err := getValuesForFieldByName(value.Field(fieldIndex), parts)
if err == nil {
return newValues, nil
}
}

return nil, fmt.Errorf("field %v does not exist for %v", fieldName, value.Type().Name())
}

return nil, fmt.Errorf("case is not handled")
}

// isFieldZero returns whether a field is set to its zero value
Expand Down
18 changes: 10 additions & 8 deletions internal/core/validate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"reflect"
"strings"

"github.com/scaleway/scaleway-cli/internal/args"
Expand Down Expand Up @@ -34,20 +35,21 @@ func DefaultCommandValidateFunc() CommandValidateFunc {
// validateArgValues validates values passed to the different args of a Command.
func validateArgValues(cmd *Command, cmdArgs interface{}) error {
for _, argSpec := range cmd.ArgSpecs {
fieldName := strings.ReplaceAll(strcase.ToPublicGoName(argSpec.Name), "."+sliceSchema, "")
fieldName = strings.ReplaceAll(fieldName, "."+mapSchema, "")
fieldValue, fieldExists := getValueForFieldByName(cmdArgs, fieldName)
if !fieldExists {
logger.Infof("could not validate arg value for '%v': invalid fieldName: %v", argSpec.Name, fieldName)
fieldName := strcase.ToPublicGoName(argSpec.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
logger.Infof("could not validate arg value for '%v': invalid fieldName: %v: %v", argSpec.Name, fieldName, err.Error())
continue
}
validateFunc := DefaultArgSpecValidateFunc()
if argSpec.ValidateFunc != nil {
validateFunc = argSpec.ValidateFunc
}
err := validateFunc(argSpec, fieldValue.Interface())
if err != nil {
return err
for _, fieldValue := range fieldValues {
err := validateFunc(argSpec, fieldValue.Interface())
if err != nil {
return err
}
}
}
return nil
Expand Down
176 changes: 176 additions & 0 deletions internal/core/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package core

import (
"fmt"
"testing"

"github.com/alecthomas/assert"
)

type Element struct {
ID int
Name string
ElementsMap map[string]Element
ElementsSlice []Element
}

type elementCustom struct {
*Element
Short string
}

func Test_DefaultCommandValidateFunc(t *testing.T) {
type TestCase struct {
command *Command
parsedArguments interface{}
}

run := func(testCase TestCase) func(t *testing.T) {
return func(t *testing.T) {
err := DefaultCommandValidateFunc()(testCase.command, testCase.parsedArguments)
assert.Equal(t, fmt.Errorf("arg validation called"), err)
}
}

t.Run("simple", run(TestCase{
command: &Command{
ArgSpecs: ArgSpecs{
{
Name: "name",
ValidateFunc: func(argSpec *ArgSpec, value interface{}) error {
return fmt.Errorf("arg validation called")
},
},
},
},
parsedArguments: &Element{
Name: "bob",
},
}))

t.Run("map", run(TestCase{
command: &Command{
ArgSpecs: ArgSpecs{
{
Name: "elements-map.{key}.id",
},
{
Name: "elements-map.{key}.name",
ValidateFunc: func(argSpec *ArgSpec, value interface{}) error {
return fmt.Errorf("arg validation called")
},
},
},
},
parsedArguments: &Element{
ElementsMap: map[string]Element{
"first": {
ID: 1,
Name: "first",
},
"second": {
ID: 2,
Name: "second",
},
},
},
}))

t.Run("slice", run(TestCase{
command: &Command{
ArgSpecs: ArgSpecs{
{
Name: "elements-slice.{index}.id",
},
{
Name: "elements-slice.{index}.name",
ValidateFunc: func(argSpec *ArgSpec, value interface{}) error {
return fmt.Errorf("arg validation called")
},
},
},
},
parsedArguments: &Element{
ElementsSlice: []Element{
{
ID: 1,
Name: "first",
},
{
ID: 2,
Name: "second",
},
},
},
}))

t.Run("slice-of-slice", run(TestCase{
command: &Command{
ArgSpecs: ArgSpecs{
{
Name: "elements-slice.{index}.id",
},
{
Name: "elements-slice.{index}.elements-slice.{index}.name",
ValidateFunc: func(argSpec *ArgSpec, value interface{}) error {
return fmt.Errorf("arg validation called")
},
},
},
},
parsedArguments: &Element{
ElementsSlice: []Element{
{
ID: 1,
},
{
ElementsSlice: []Element{
{
Name: "bob",
},
},
},
},
},
}))

t.Run("new-field", run(TestCase{
command: &Command{
ArgSpecs: ArgSpecs{
{
Name: "name",
},
{
Name: "short",
ValidateFunc: func(argSpec *ArgSpec, value interface{}) error {
return fmt.Errorf("arg validation called")
},
},
},
},
parsedArguments: &elementCustom{
Short: "bob",
},
}))

t.Run("anonymous-field", run(TestCase{
command: &Command{
ArgSpecs: ArgSpecs{
{
Name: "short",
},
{
Name: "name",
ValidateFunc: func(argSpec *ArgSpec, value interface{}) error {
return fmt.Errorf("arg validation called")
},
},
},
},
parsedArguments: &elementCustom{
Element: &Element{
Name: "bob",
},
},
}))
}

0 comments on commit fa49149

Please sign in to comment.