Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

r/aws_securitylake_data_lake: Fix panic on import #34820

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/34820.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
resource/aws_securitylake_data_lake: Fix `reflect.Set: value of type basetypes.StringValue is not assignable to type types.ARN` panic when importing resources with `nil` ARN fields
```
78 changes: 37 additions & 41 deletions internal/framework/flex/autoflex.go
Original file line number Diff line number Diff line change
Expand Up @@ -855,19 +855,19 @@ func (flattener autoFlattener) convert(ctx context.Context, vFrom, vTo reflect.V
tTo := valTo.Type(ctx)
switch k := vFrom.Kind(); k {
case reflect.Bool:
diags.Append(flattener.bool(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.bool(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.Float32, reflect.Float64:
diags.Append(flattener.float(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.float(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.Int32, reflect.Int64:
diags.Append(flattener.int(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.int(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.String:
diags.Append(flattener.string(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.string(ctx, vFrom, false, tTo, vTo)...)
return diags

case reflect.Ptr:
Expand Down Expand Up @@ -898,12 +898,16 @@ func (flattener autoFlattener) convert(ctx context.Context, vFrom, vTo reflect.V
}

// bool copies an AWS API bool value to a compatible Plugin Framework value.
func (flattener autoFlattener) bool(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) bool(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.BoolTypable:
v, d := tTo.ValueFromBool(ctx, types.BoolValue(vFrom.Bool()))
boolValue := types.BoolNull()
if !isNullFrom {
boolValue = types.BoolValue(vFrom.Bool())
}
v, d := tTo.ValueFromBool(ctx, boolValue)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -925,12 +929,16 @@ func (flattener autoFlattener) bool(ctx context.Context, vFrom reflect.Value, tT
}

// float copies an AWS API float value to a compatible Plugin Framework value.
func (flattener autoFlattener) float(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) float(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.Float64Typable:
v, d := tTo.ValueFromFloat64(ctx, types.Float64Value(vFrom.Float()))
float64Value := types.Float64Null()
if !isNullFrom {
float64Value = types.Float64Value(vFrom.Float())
}
v, d := tTo.ValueFromFloat64(ctx, float64Value)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -952,12 +960,16 @@ func (flattener autoFlattener) float(ctx context.Context, vFrom reflect.Value, t
}

// int copies an AWS API int value to a compatible Plugin Framework value.
func (flattener autoFlattener) int(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) int(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.Int64Typable:
v, d := tTo.ValueFromInt64(ctx, types.Int64Value(vFrom.Int()))
int64Value := types.Int64Null()
if !isNullFrom {
int64Value = types.Int64Value(vFrom.Int())
}
v, d := tTo.ValueFromInt64(ctx, int64Value)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -979,12 +991,16 @@ func (flattener autoFlattener) int(ctx context.Context, vFrom reflect.Value, tTo
}

// string copies an AWS API string value to a compatible Plugin Framework value.
func (flattener autoFlattener) string(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) string(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch tTo := tTo.(type) {
case basetypes.StringTypable:
v, d := tTo.ValueFromString(ctx, types.StringValue(vFrom.String()))
stringValue := types.StringNull()
if !isNullFrom {
stringValue = types.StringValue(vFrom.String())
}
v, d := tTo.ValueFromString(ctx, stringValue)
diags.Append(d...)
if diags.HasError() {
return diags
Expand All @@ -1009,49 +1025,29 @@ func (flattener autoFlattener) string(ctx context.Context, vFrom reflect.Value,
func (flattener autoFlattener) ptr(ctx context.Context, vFrom reflect.Value, tTo attr.Type, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

switch vElem := vFrom.Elem(); vFrom.Type().Elem().Kind() {
switch vElem, isNilFrom := vFrom.Elem(), vFrom.IsNil(); vFrom.Type().Elem().Kind() {
case reflect.Bool:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.BoolNull()))
return diags
}

diags.Append(flattener.bool(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.bool(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.Float32, reflect.Float64:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.Float64Null()))
return diags
}

diags.Append(flattener.float(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.float(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.Int32, reflect.Int64:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.Int64Null()))
return diags
}

diags.Append(flattener.int(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.int(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.String:
if vFrom.IsNil() {
vTo.Set(reflect.ValueOf(types.StringNull()))
return diags
}

diags.Append(flattener.string(ctx, vElem, tTo, vTo)...)
diags.Append(flattener.string(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags

case reflect.Struct:
if tTo, ok := tTo.(fwtypes.NestedObjectType); ok {
//
// *struct -> types.List(OfObject).
//
diags.Append(flattener.ptrToStructNestedObject(ctx, vFrom, tTo, vTo)...)
diags.Append(flattener.ptrToStructNestedObject(ctx, vElem, isNilFrom, tTo, vTo)...)
return diags
}
}
Expand Down Expand Up @@ -1407,10 +1403,10 @@ func (flattener autoFlattener) structToNestedObject(ctx context.Context, vFrom r
}

// ptrToStructNestedObject copies an AWS API *struct value to a compatible Plugin Framework NestedObjectValue value.
func (flattener autoFlattener) ptrToStructNestedObject(ctx context.Context, vFrom reflect.Value, tTo fwtypes.NestedObjectType, vTo reflect.Value) diag.Diagnostics {
func (flattener autoFlattener) ptrToStructNestedObject(ctx context.Context, vFrom reflect.Value, isNullFrom bool, tTo fwtypes.NestedObjectType, vTo reflect.Value) diag.Diagnostics {
var diags diag.Diagnostics

if vFrom.IsNil() {
if isNullFrom {
val, d := tTo.NullValue(ctx)
diags.Append(d...)
if diags.HasError() {
Expand All @@ -1428,7 +1424,7 @@ func (flattener autoFlattener) ptrToStructNestedObject(ctx context.Context, vFro
return diags
}

diags.Append(autoFlexConvertStruct(ctx, vFrom.Elem().Interface(), to, flattener)...)
diags.Append(autoFlexConvertStruct(ctx, vFrom.Interface(), to, flattener)...)
if diags.HasError() {
return diags
}
Expand Down
120 changes: 120 additions & 0 deletions internal/framework/flex/autoflex_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,65 @@ func TestGenericExpandAdvanced(t *testing.T) {
}
}

type TestFlexTF17 struct {
Field1 fwtypes.ARN `tfsdk:"field1"`
}

func TestGenericExpandCustomStringType(t *testing.T) {
t.Parallel()

a := "arn:aws:securityhub:us-west-2:1234567890:control/cis-aws-foundations-benchmark/v/1.2.0/1.1" //lintignore:AWSAT003,AWSAT005
ctx := context.Background()
testCases := []struct {
Context context.Context //nolint:containedctx // testing context use
TestName string
Source any
Target any
WantErr bool
WantTarget any
}{
{
TestName: "single ARN Source and single string Target",
Source: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
Target: &TestFlexAWS01{},
WantTarget: &TestFlexAWS01{Field1: a},
},
{
TestName: "single ARN Source and single *string Target",
Source: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
Target: &TestFlexAWS02{},
WantTarget: &TestFlexAWS02{Field1: aws.String(a)},
},
}

for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.TestName, func(t *testing.T) {
t.Parallel()

testCtx := ctx //nolint:contextcheck // simplify use of testing context
if testCase.Context != nil {
testCtx = testCase.Context
}

err := Expand(testCtx, testCase.Source, testCase.Target)
gotErr := err != nil

if gotErr != testCase.WantErr {
t.Errorf("gotErr = %v, wantErr = %v", gotErr, testCase.WantErr)
}

if gotErr {
if !testCase.WantErr {
t.Errorf("err = %q", err)
}
} else if diff := cmp.Diff(testCase.Target, testCase.WantTarget); diff != "" {
t.Errorf("unexpected diff (+wanted, -got): %s", diff)
}
})
}
}

func TestGenericFlatten(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1392,3 +1451,64 @@ func TestGenericFlattenAdvanced(t *testing.T) {
})
}
}

func TestGenericFlattenCustomStringType(t *testing.T) {
t.Parallel()

a := "arn:aws:securityhub:us-west-2:1234567890:control/cis-aws-foundations-benchmark/v/1.2.0/1.1" //lintignore:AWSAT003,AWSAT005
ctx := context.Background()
testCases := []struct {
Context context.Context //nolint:containedctx // testing context use
TestName string
Source any
Target any
WantErr bool
WantTarget any
}{
{
TestName: "single string Source and single ARN Target",
Source: &TestFlexAWS01{Field1: a},
Target: &TestFlexTF17{},
WantTarget: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
},
{
TestName: "single *string Source and single ARN Target",
Source: &TestFlexAWS02{Field1: aws.String(a)},
Target: &TestFlexTF17{},
WantTarget: &TestFlexTF17{Field1: fwtypes.ARNValue(a)},
},
{
TestName: "single nil *string Source and single ARN Target",
Source: &TestFlexAWS02{},
Target: &TestFlexTF17{},
WantTarget: &TestFlexTF17{Field1: fwtypes.ARNNull()},
},
}

for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.TestName, func(t *testing.T) {
t.Parallel()

testCtx := ctx //nolint:contextcheck // simplify use of testing context
if testCase.Context != nil {
testCtx = testCase.Context
}

err := Flatten(testCtx, testCase.Source, testCase.Target)
gotErr := err != nil

if gotErr != testCase.WantErr {
t.Errorf("gotErr = %v, wantErr = %v", gotErr, testCase.WantErr)
}

if gotErr {
if !testCase.WantErr {
t.Errorf("err = %q", err)
}
} else if diff := cmp.Diff(testCase.Target, testCase.WantTarget); diff != "" {
t.Errorf("unexpected diff (+wanted, -got): %s", diff)
}
})
}
}