From fa491493468c17c5486ef3c1d213a16a3fd3ff09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Bourgois?= Date: Wed, 19 Feb 2020 17:36:13 +0100 Subject: [PATCH] fix: recursive arg validation (#712) --- internal/core/reflect.go | 80 +++++++++++++-- internal/core/validate.go | 18 ++-- internal/core/validate_test.go | 176 +++++++++++++++++++++++++++++++++ 3 files changed, 260 insertions(+), 14 deletions(-) create mode 100644 internal/core/validate_test.go diff --git a/internal/core/reflect.go b/internal/core/reflect.go index 5a9da508fc..28cad2f066 100644 --- a/internal/core/reflect.go +++ b/internal/core/reflect.go @@ -1,7 +1,9 @@ package core import ( + "fmt" "reflect" + "sort" "strings" "github.com/scaleway/scaleway-sdk-go/strcase" @@ -20,14 +22,80 @@ func newObjectWithForcedJSONTags(t reflect.Type) interface{} { return reflect.New(reflect.StructOf(structFieldsCopy)).Interface() } -// getValueForFieldByName search for a field in a cmdArgs and returns its value if this field exists. +// getValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist. // The search is based on the name of the field. -func getValueForFieldByName(cmdArgs interface{}, fieldName string) (value reflect.Value, isValid bool) { - field := reflect.ValueOf(cmdArgs).Elem().FieldByName(fieldName) - if !field.IsValid() { - return field, false +func getValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) { + if len(parts) == 0 { + return []reflect.Value{value}, nil } - return field, true + + switch value.Kind() { + case reflect.Ptr: + return getValuesForFieldByName(value.Elem(), parts) + + case reflect.Slice: + values := []reflect.Value(nil) + for i := 0; i < value.Len(); i++ { + newValues, err := getValuesForFieldByName(value.Index(i), parts[1:]) + if err != nil { + return nil, err + } + values = append(values, newValues...) + } + return values, nil + + case reflect.Map: + if value.IsNil() { + return nil, nil + } + + values := []reflect.Value(nil) + + mapKeys := value.MapKeys() + sort.Slice(mapKeys, func(i, j int) bool { + return mapKeys[i].String() < mapKeys[j].String() + }) + + for _, mapKey := range mapKeys { + mapValue := value.MapIndex(mapKey) + newValues, err := getValuesForFieldByName(mapValue, parts[1:]) + if err != nil { + return nil, err + } + values = append(values, newValues...) + } + return values, nil + + case reflect.Struct: + anonymousFieldIndexes := []int(nil) + fieldIndexByName := map[string]int{} + + for i := 0; i < value.NumField(); i++ { + field := value.Type().Field(i) + if field.Anonymous { + anonymousFieldIndexes = append(anonymousFieldIndexes, i) + } else { + fieldIndexByName[field.Name] = i + } + } + + fieldName := strcase.ToPublicGoName(parts[0]) + if fieldIndex, exist := fieldIndexByName[fieldName]; exist { + return getValuesForFieldByName(value.Field(fieldIndex), parts[1:]) + } + + // If it does not exist we try to find it in nested anonymous field + for fieldIndex := len(anonymousFieldIndexes) - 1; fieldIndex >= 0; fieldIndex-- { + newValues, err := getValuesForFieldByName(value.Field(fieldIndex), parts) + if err == nil { + return newValues, nil + } + } + + return nil, fmt.Errorf("field %v does not exist for %v", fieldName, value.Type().Name()) + } + + return nil, fmt.Errorf("case is not handled") } // isFieldZero returns whether a field is set to its zero value diff --git a/internal/core/validate.go b/internal/core/validate.go index ac6dd93829..6867e1f40b 100644 --- a/internal/core/validate.go +++ b/internal/core/validate.go @@ -1,6 +1,7 @@ package core import ( + "reflect" "strings" "github.com/scaleway/scaleway-cli/internal/args" @@ -34,20 +35,21 @@ func DefaultCommandValidateFunc() CommandValidateFunc { // validateArgValues validates values passed to the different args of a Command. func validateArgValues(cmd *Command, cmdArgs interface{}) error { for _, argSpec := range cmd.ArgSpecs { - fieldName := strings.ReplaceAll(strcase.ToPublicGoName(argSpec.Name), "."+sliceSchema, "") - fieldName = strings.ReplaceAll(fieldName, "."+mapSchema, "") - fieldValue, fieldExists := getValueForFieldByName(cmdArgs, fieldName) - if !fieldExists { - logger.Infof("could not validate arg value for '%v': invalid fieldName: %v", argSpec.Name, fieldName) + fieldName := strcase.ToPublicGoName(argSpec.Name) + fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, ".")) + if err != nil { + logger.Infof("could not validate arg value for '%v': invalid fieldName: %v: %v", argSpec.Name, fieldName, err.Error()) continue } validateFunc := DefaultArgSpecValidateFunc() if argSpec.ValidateFunc != nil { validateFunc = argSpec.ValidateFunc } - err := validateFunc(argSpec, fieldValue.Interface()) - if err != nil { - return err + for _, fieldValue := range fieldValues { + err := validateFunc(argSpec, fieldValue.Interface()) + if err != nil { + return err + } } } return nil diff --git a/internal/core/validate_test.go b/internal/core/validate_test.go new file mode 100644 index 0000000000..c273bbbe1b --- /dev/null +++ b/internal/core/validate_test.go @@ -0,0 +1,176 @@ +package core + +import ( + "fmt" + "testing" + + "github.com/alecthomas/assert" +) + +type Element struct { + ID int + Name string + ElementsMap map[string]Element + ElementsSlice []Element +} + +type elementCustom struct { + *Element + Short string +} + +func Test_DefaultCommandValidateFunc(t *testing.T) { + type TestCase struct { + command *Command + parsedArguments interface{} + } + + run := func(testCase TestCase) func(t *testing.T) { + return func(t *testing.T) { + err := DefaultCommandValidateFunc()(testCase.command, testCase.parsedArguments) + assert.Equal(t, fmt.Errorf("arg validation called"), err) + } + } + + t.Run("simple", run(TestCase{ + command: &Command{ + ArgSpecs: ArgSpecs{ + { + Name: "name", + ValidateFunc: func(argSpec *ArgSpec, value interface{}) error { + return fmt.Errorf("arg validation called") + }, + }, + }, + }, + parsedArguments: &Element{ + Name: "bob", + }, + })) + + t.Run("map", run(TestCase{ + command: &Command{ + ArgSpecs: ArgSpecs{ + { + Name: "elements-map.{key}.id", + }, + { + Name: "elements-map.{key}.name", + ValidateFunc: func(argSpec *ArgSpec, value interface{}) error { + return fmt.Errorf("arg validation called") + }, + }, + }, + }, + parsedArguments: &Element{ + ElementsMap: map[string]Element{ + "first": { + ID: 1, + Name: "first", + }, + "second": { + ID: 2, + Name: "second", + }, + }, + }, + })) + + t.Run("slice", run(TestCase{ + command: &Command{ + ArgSpecs: ArgSpecs{ + { + Name: "elements-slice.{index}.id", + }, + { + Name: "elements-slice.{index}.name", + ValidateFunc: func(argSpec *ArgSpec, value interface{}) error { + return fmt.Errorf("arg validation called") + }, + }, + }, + }, + parsedArguments: &Element{ + ElementsSlice: []Element{ + { + ID: 1, + Name: "first", + }, + { + ID: 2, + Name: "second", + }, + }, + }, + })) + + t.Run("slice-of-slice", run(TestCase{ + command: &Command{ + ArgSpecs: ArgSpecs{ + { + Name: "elements-slice.{index}.id", + }, + { + Name: "elements-slice.{index}.elements-slice.{index}.name", + ValidateFunc: func(argSpec *ArgSpec, value interface{}) error { + return fmt.Errorf("arg validation called") + }, + }, + }, + }, + parsedArguments: &Element{ + ElementsSlice: []Element{ + { + ID: 1, + }, + { + ElementsSlice: []Element{ + { + Name: "bob", + }, + }, + }, + }, + }, + })) + + t.Run("new-field", run(TestCase{ + command: &Command{ + ArgSpecs: ArgSpecs{ + { + Name: "name", + }, + { + Name: "short", + ValidateFunc: func(argSpec *ArgSpec, value interface{}) error { + return fmt.Errorf("arg validation called") + }, + }, + }, + }, + parsedArguments: &elementCustom{ + Short: "bob", + }, + })) + + t.Run("anonymous-field", run(TestCase{ + command: &Command{ + ArgSpecs: ArgSpecs{ + { + Name: "short", + }, + { + Name: "name", + ValidateFunc: func(argSpec *ArgSpec, value interface{}) error { + return fmt.Errorf("arg validation called") + }, + }, + }, + }, + parsedArguments: &elementCustom{ + Element: &Element{ + Name: "bob", + }, + }, + })) +}