Skip to content

Commit

Permalink
Merge pull request #34820 from hashicorp/b-aws-securitylake_data_lake…
Browse files Browse the repository at this point in the history
…-import.crash

r/aws_securitylake_data_lake: Fix panic on import
  • Loading branch information
ewbankkit authored Dec 8, 2023
2 parents bea5560 + 0d8c0bf commit bc1b46e
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 41 deletions.
3 changes: 3 additions & 0 deletions .changelog/34820.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
resource/aws_securitylake_data_lake: Fix `reflect.Set: value of type basetypes.StringValue is not assignable to type types.ARN` panic when importing resources with `nil` ARN fields
```
78 changes: 37 additions & 41 deletions internal/framework/flex/autoflex.go
Original file line number Diff line number Diff line change
Expand Up @@ -855,19 +855,19 @@ func (flattener autoFlattener) convert(ctx context.Context, vFrom, vTo reflect.V
tTo := valTo.Type(ctx)
switch k := vFrom.Kind(); k {
case reflect.Bool:
diags.Append(flattener.bool(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.bool(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.Float32, reflect.Float64:
diags.Append(flattener.float(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.float(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.Int32, reflect.Int64:
diags.Append(flattener.int(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.int(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.String:
diags.Append(flattener.string(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.string(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.Ptr:
Expand Down Expand Up @@ -898,12 +898,16 @@ func (flattener autoFlattener) convert(ctx context.Context, vFrom, vTo reflect.V
}

// bool copies an AWS API bool value to a compatible Plugin Framework value.
func (flattener autoFlattener) bool(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) bool(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.BoolTypable:
v, d := tTo.ValueFromBool(ctx, types.BoolValue(vFrom.Bool()))
boolValue := types.BoolNull()
if !isNullFrom {
boolValue = types.BoolValue(vFrom.Bool())
}
v, d := tTo.ValueFromBool(ctx, boolValue)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -925,12 +929,16 @@ func (flattener autoFlattener) bool(ctx context.Context, vFrom reflect.Value, tT
}

// float copies an AWS API float value to a compatible Plugin Framework value.
func (flattener autoFlattener) float(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) float(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.Float64Typable:
v, d := tTo.ValueFromFloat64(ctx, types.Float64Value(vFrom.Float()))
float64Value := types.Float64Null()
if !isNullFrom {
float64Value = types.Float64Value(vFrom.Float())
}
v, d := tTo.ValueFromFloat64(ctx, float64Value)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -952,12 +960,16 @@ func (flattener autoFlattener) float(ctx context.Context, vFrom reflect.Value, t
}

// int copies an AWS API int value to a compatible Plugin Framework value.
func (flattener autoFlattener) int(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) int(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.Int64Typable:
v, d := tTo.ValueFromInt64(ctx, types.Int64Value(vFrom.Int()))
int64Value := types.Int64Null()
if !isNullFrom {
int64Value = types.Int64Value(vFrom.Int())
}
v, d := tTo.ValueFromInt64(ctx, int64Value)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -979,12 +991,16 @@ func (flattener autoFlattener) int(ctx context.Context, vFrom reflect.Value, tTo
}

// string copies an AWS API string value to a compatible Plugin Framework value.
func (flattener autoFlattener) string(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) string(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.StringTypable:
v, d := tTo.ValueFromString(ctx, types.StringValue(vFrom.String()))
stringValue := types.StringNull()
if !isNullFrom {
stringValue = types.StringValue(vFrom.String())
}
v, d := tTo.ValueFromString(ctx, stringValue)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -1009,49 +1025,29 @@ func (flattener autoFlattener) string(ctx context.Context, vFrom reflect.Value,
func (flattener autoFlattener) ptr(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch vElem := vFrom.Elem(); vFrom.Type().Elem().Kind() {
switch vElem, isNilFrom := vFrom.Elem(), vFrom.IsNil(); vFrom.Type().Elem().Kind() {
case reflect.Bool:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.BoolNull()))
return diags
}

diags.Append(flattener.bool(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.bool(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.Float32, reflect.Float64:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.Float64Null()))
return diags
}

diags.Append(flattener.float(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.float(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.Int32, reflect.Int64:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.Int64Null()))
return diags
}

diags.Append(flattener.int(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.int(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.String:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.StringNull()))
return diags
}

diags.Append(flattener.string(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.string(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.Struct:
if tTo, ok := tTo.(fwtypes.NestedObjectType); ok {
//
// *struct -> types.List(OfObject).
//
diags.Append(flattener.ptrToStructNestedObject(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.ptrToStructNestedObject(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags
}
}
Expand Down Expand Up @@ -1407,10 +1403,10 @@ func (flattener autoFlattener) structToNestedObject(ctx context.Context, vFrom r
}

// ptrToStructNestedObject copies an AWS API *struct value to a compatible Plugin Framework NestedObjectValue value.
func (flattener autoFlattener) ptrToStructNestedObject(ctx context.Context, vFrom reflect.Value, tTo fwtypes.NestedObjectType, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) ptrToStructNestedObject(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo fwtypes.NestedObjectType, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

if vFrom.IsNil() {
if isNullFrom {
val, d := tTo.NullValue(ctx)
diags.Append(d...)
if diags.HasError() {
Expand All @@ -1428,7 +1424,7 @@ func (flattener autoFlattener) ptrToStructNestedObject(ctx context.Context, vFro
return diags
}

diags.Append(autoFlexConvertStruct(ctx, vFrom.Elem().Interface(), to, flattener)...)
diags.Append(autoFlexConvertStruct(ctx, vFrom.Interface(), to, flattener)...)
if diags.HasError() {
return diags
}
Expand Down
120 changes: 120 additions & 0 deletions internal/framework/flex/autoflex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,65 @@ func TestGenericExpandAdvanced(t *testing.T) {
}
}

type TestFlexTF17 struct {
Field1 fwtypes.ARN `tfsdk:"field1"`
}

func TestGenericExpandCustomStringType(t *testing.T) {
t.Parallel()

a := "arn:aws:securityhub:us-west-2:1234567890:control/cis-aws-foundations-benchmark/v/1.2.0/1.1" //lintignore:AWSAT003,AWSAT005
ctx := context.Background()
testCases := []struct {
Context context.Context //nolint:containedctx // testing context use
TestName string
Source any
Target any
WantErr bool
WantTarget any
}{
{
TestName: "single ARN Source and single string Target",
Source: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
Target: &TestFlexAWS01{},
WantTarget: &TestFlexAWS01{Field1: a},
},
{
TestName: "single ARN Source and single *string Target",
Source: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
Target: &TestFlexAWS02{},
WantTarget: &TestFlexAWS02{Field1: aws.String(a)},
},
}

for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.TestName, func(t *testing.T) {
t.Parallel()

testCtx := ctx //nolint:contextcheck // simplify use of testing context
if testCase.Context != nil {
testCtx = testCase.Context
}

err := Expand(testCtx, testCase.Source, testCase.Target)
gotErr := err != nil

if gotErr != testCase.WantErr {
t.Errorf("gotErr = %v, wantErr = %v", gotErr, testCase.WantErr)
}

if gotErr {
if !testCase.WantErr {
t.Errorf("err = %q", err)
}
} else if diff := cmp.Diff(testCase.Target, testCase.WantTarget); diff != "" {
t.Errorf("unexpected diff (+wanted, -got): %s", diff)
}
})
}
}

func TestGenericFlatten(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1392,3 +1451,64 @@ func TestGenericFlattenAdvanced(t *testing.T) {
})
}
}

func TestGenericFlattenCustomStringType(t *testing.T) {
t.Parallel()

a := "arn:aws:securityhub:us-west-2:1234567890:control/cis-aws-foundations-benchmark/v/1.2.0/1.1" //lintignore:AWSAT003,AWSAT005
ctx := context.Background()
testCases := []struct {
Context context.Context //nolint:containedctx // testing context use
TestName string
Source any
Target any
WantErr bool
WantTarget any
}{
{
TestName: "single string Source and single ARN Target",
Source: &TestFlexAWS01{Field1: a},
Target: &TestFlexTF17{},
WantTarget: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
},
{
TestName: "single *string Source and single ARN Target",
Source: &TestFlexAWS02{Field1: aws.String(a)},
Target: &TestFlexTF17{},
WantTarget: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
},
{
TestName: "single nil *string Source and single ARN Target",
Source: &TestFlexAWS02{},
Target: &TestFlexTF17{},
WantTarget: &TestFlexTF17{Field1: fwtypes.ARNNull()},
},
}

for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.TestName, func(t *testing.T) {
t.Parallel()

testCtx := ctx //nolint:contextcheck // simplify use of testing context
if testCase.Context != nil {
testCtx = testCase.Context
}

err := Flatten(testCtx, testCase.Source, testCase.Target)
gotErr := err != nil

if gotErr != testCase.WantErr {
t.Errorf("gotErr = %v, wantErr = %v", gotErr, testCase.WantErr)
}

if gotErr {
if !testCase.WantErr {
t.Errorf("err = %q", err)
}
} else if diff := cmp.Diff(testCase.Target, testCase.WantTarget); diff != "" {
t.Errorf("unexpected diff (+wanted, -got): %s", diff)
}
})
}
}

0 comments on commit bc1b46e

Please sign in to comment.