diff --git a/pkg/core/persist.go b/pkg/core/persist.go index 342dac77c..819518f57 100644 --- a/pkg/core/persist.go +++ b/pkg/core/persist.go @@ -7,13 +7,9 @@ import ( type ( Persist struct { - Name string - Kind PersistKind - } - - Fs struct { - Persist - GenerateNewFs bool + Name string + Kind PersistKind + GenerateNew bool } Secrets struct { diff --git a/pkg/lang/golang/arguments.go b/pkg/lang/golang/arguments.go index ffa32b329..562755997 100644 --- a/pkg/lang/golang/arguments.go +++ b/pkg/lang/golang/arguments.go @@ -3,6 +3,7 @@ package golang import ( "fmt" + "github.com/klothoplatform/klotho/pkg/query" sitter "github.com/smacker/go-tree-sitter" ) @@ -12,27 +13,36 @@ type Argument struct { } // GetArguements is passed a tree-sitter node, which is of type argument_list, and returns a list of in order Arguments -func GetArguements(args *sitter.Node) []Argument { - - arguments := []Argument{} - nextMatch := doQuery(args, findArgs) +func getArguements(args *sitter.Node) (arguments []Argument, found bool) { + fnName := "" + nextMatch := doQuery(args, findFunctionCall) for { match, found := nextMatch() if !found { break } - + fn := match["function"] arg := match["arg"] + + if fnName != "" && !query.NodeContentEquals(fn, fnName) { + break + } + + fnName = fn.Content() + if arg == nil { continue } arguments = append(arguments, Argument{Content: arg.Content(), Type: arg.Type()}) } - return arguments + if fnName != "" { + found = true + } + return } -func ArgumentListToString(args []Argument) string { +func argumentListToString(args []Argument) string { result := "(" for index, arg := range args { if index < len(args)-1 { diff --git a/pkg/lang/golang/arguments_test.go b/pkg/lang/golang/arguments_test.go new file mode 100644 index 000000000..f0baea7d2 --- /dev/null +++ b/pkg/lang/golang/arguments_test.go @@ -0,0 +1,58 @@ +package golang + +import ( + "strings" + "testing" + + "github.com/klothoplatform/klotho/pkg/core" + "github.com/stretchr/testify/assert" +) + +func Test_GetArguements(t *testing.T) { + tests := []struct { + name string + source string + want []Argument + wantFound bool + }{ + { + name: "finds next function Name and args", + source: ` + x = s.my_func("val") + y = s.other_func("something_else) + `, + want: []Argument{ + {Content: `"val"`, Type: "interpreted_string_literal"}, + }, + wantFound: true, + }, + { + name: "args not required", + source: `v, err := s.someFunc()`, + wantFound: false, + }, + { + name: "a call containing other function calls as args", + source: `v, err := runtimevar.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path))`, + want: []Argument{ + {Content: "context.TODO()", Type: "call_expression"}, + {Content: `fmt.Sprintf("file://%s?decoder=string", path)`, Type: "call_expression"}, + }, + wantFound: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + + f, err := core.NewSourceFile("", strings.NewReader(tt.source), Language) + if !assert.NoError(err) { + return + } + args, found := getArguements(f.Tree().RootNode()) + + assert.ElementsMatch(tt.want, args) + assert.Equal(tt.wantFound, found) + }) + } +} diff --git a/pkg/lang/golang/aws_runtime/Lambda_Dockerfile b/pkg/lang/golang/aws_runtime/Lambda_Dockerfile index e09fafdb7..036831198 100644 --- a/pkg/lang/golang/aws_runtime/Lambda_Dockerfile +++ b/pkg/lang/golang/aws_runtime/Lambda_Dockerfile @@ -1,13 +1,13 @@ -FROM public.ecr.aws/lambda/provided:al2 - -WORKDIR ${LAMBDA_TASK_ROOT} +FROM golang:1.20 as builder -RUN yum install -y golang -RUN go env -w GOPROXY=https://proxy.golang.org,direct +WORKDIR /usr/src/app +ENV GOOS=linux GOARCH=amd64 CGO_ENABLED=0 +COPY go.mod ./ +RUN go mod tidy && go mod download && go mod verify COPY . . -RUN env GOOS=linux GOARCH=amd64 CGO_ENABLED=0 -RUN go mod tidy -RUN go build -o=/main +RUN go build -o /usr/local/bin/app -ENTRYPOINT ["/main"] \ No newline at end of file +FROM public.ecr.aws/lambda/provided:al2 +COPY --from=builder /usr/local/bin/app main +ENTRYPOINT ["/main"] diff --git a/pkg/lang/golang/aws_runtime/aws.go b/pkg/lang/golang/aws_runtime/aws.go index 298490709..3b0e3870c 100644 --- a/pkg/lang/golang/aws_runtime/aws.go +++ b/pkg/lang/golang/aws_runtime/aws.go @@ -65,3 +65,12 @@ func (r *AwsRuntime) GetFsImports() []golang.Import { {Package: "gocloud.dev/blob/s3blob", Alias: "_"}, } } + +func (r *AwsRuntime) GetSecretsImports() []golang.Import { + return []golang.Import{ + {Package: "os"}, + {Package: "strings"}, + {Package: "gocloud.dev/runtimevar"}, + {Package: "gocloud.dev/runtimevar/awssecretsmanager", Alias: "_"}, + } +} diff --git a/pkg/lang/golang/plugin_fs.go b/pkg/lang/golang/plugin_fs.go index 8a2c91d4e..036ca5915 100644 --- a/pkg/lang/golang/plugin_fs.go +++ b/pkg/lang/golang/plugin_fs.go @@ -1,8 +1,8 @@ package golang import ( - "errors" "fmt" + "strings" "github.com/klothoplatform/klotho/pkg/annotation" "github.com/klothoplatform/klotho/pkg/core" @@ -81,47 +81,45 @@ func (p *PersistFsPlugin) transformFS(f *core.SourceFile, cap *core.Annotation, unit.EnvironmentVariables = append(unit.EnvironmentVariables, fsEnvVar) - args := GetArguements(result.args) + args, _ := getArguements(result.expression) + // Generate the new node content before replacing the node. We just set it so we can compile correctly + newNodeContent := `var _ = ` + args[1].Content + "\n" - // We need to check to make sure the path supplied to the original node content is a static string. This is because it will get erased and we dont want to leave os level orphaned code - if !args[0].IsString() { - return nil, errors.New("must supply static string for secret path") - } - - args[0].Content = "nil" args[1].Content = fmt.Sprintf(`"s3://" + os.Getenv("%s") + "?region=" + os.Getenv("AWS_REGION")`, fsEnvVar.Name) - err := f.ReplaceNodeContent(result.args, ArgumentListToString(args)) - if err != nil { - return nil, err - } - err = f.ReplaceNodeContent(result.operator, "blob") + newArgContent := argumentListToString(args) + + newExpressionContent := strings.ReplaceAll(result.expression.Content(), result.args.Content(), newArgContent) + newNodeContent += newExpressionContent + + err := f.ReplaceNodeContent(result.expression, newNodeContent) if err != nil { return nil, err } - err = UpdateImportsInFile(f, p.runtime.GetFsImports(), []Import{{Package: "gocloud.dev/blob/fileblob"}}) + err = UpdateImportsInFile(f, p.runtime.GetFsImports(), []Import{}) if err != nil { return nil, err } persist := &core.Persist{ - Kind: core.PersistFileKind, - Name: cap.Capability.ID, + Kind: core.PersistFileKind, + Name: cap.Capability.ID, + GenerateNew: true, } return persist, nil } type persistResult struct { - varName string - operator *sitter.Node - args *sitter.Node + varName string + expression *sitter.Node + args *sitter.Node } func queryFS(file *core.SourceFile, annotation *core.Annotation) *persistResult { log := zap.L().With(logging.FileField(file), logging.AnnotationField(annotation)) - fileBlobImport := GetNamedImportInFile(file, "gocloud.dev/blob/fileblob") + fileBlobImport := GetNamedImportInFile(file, "gocloud.dev/blob") nextMatch := doQuery(annotation.Node, fileBucket) @@ -138,7 +136,7 @@ func queryFS(file *core.SourceFile, annotation *core.Annotation) *persistResult return nil } } else { - if !query.NodeContentEquals(id, "fileblob") { + if !query.NodeContentEquals(id, "blob") { return nil } } @@ -150,8 +148,8 @@ func queryFS(file *core.SourceFile, annotation *core.Annotation) *persistResult } return &persistResult{ - varName: varName.Content(), - operator: id, - args: args, + varName: varName.Content(), + expression: match["expression"], + args: args, } } diff --git a/pkg/lang/golang/plugin_fs_test.go b/pkg/lang/golang/plugin_fs_test.go index 02bc4f513..245f8be4d 100644 --- a/pkg/lang/golang/plugin_fs_test.go +++ b/pkg/lang/golang/plugin_fs_test.go @@ -1,6 +1,7 @@ package golang import ( + "fmt" "strings" "testing" @@ -19,14 +20,14 @@ func Test_queryFS(t *testing.T) { name: "simple file blob", source: ` import ( - "gocloud.dev/blob/fileblob" + "gocloud.dev/blob" ) /** * @klotho::persist { * id = "test" * } */ -bucket, err := fileblob.OpenBucket("myDir", nil)`, +bucket, err := blob.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path))`, want: &persistResult{ varName: "bucket", }, @@ -35,14 +36,14 @@ bucket, err := fileblob.OpenBucket("myDir", nil)`, name: "simple var file blob", source: ` import ( - "gocloud.dev/blob/fileblob" + "gocloud.dev/blob" ) /** * @klotho::persist { * id = "test" * } */ -var bucket, err = fileblob.OpenBucket("myDir", nil)`, +var bucket, err = blob.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path))`, want: &persistResult{ varName: "bucket", }, @@ -51,15 +52,16 @@ var bucket, err = fileblob.OpenBucket("myDir", nil)`, name: "simple var file blob", source: ` import ( - "gocloud.dev/blob/fileblob" + "gocloud.dev/blob" ) +var bucket *blob.Bucket +var err error /** * @klotho::persist { * id = "test" * } */ -var bucket, err -bucket, err = fileblob.OpenBucket("myDir", nil)`, +bucket, err = blob.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path))`, want: &persistResult{ varName: "bucket", }, @@ -68,30 +70,14 @@ bucket, err = fileblob.OpenBucket("myDir", nil)`, name: "aliased file blob", source: ` import ( - alias "gocloud.dev/blob/fileblob" - ) - /** - * @klotho::persist { - * id = "test" - * } - */ - bucket, err := alias.OpenBucket("myDir", nil)`, - want: &persistResult{ - varName: "bucket", - }, - }, - { - name: "non string as path still found in query", - source: ` - import ( - alias "gocloud.dev/blob/fileblob" + alias "gocloud.dev/blob" ) /** * @klotho::persist { * id = "test" * } */ - bucket, err := alias.OpenBucket(myDir, nil)`, + bucket, err := alias.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path))`, want: &persistResult{ varName: "bucket", }, @@ -100,14 +86,14 @@ bucket, err = fileblob.OpenBucket("myDir", nil)`, name: "wrong import no match", source: ` import ( - "gocloud.dev/blob/fileblobby" + "gocloud.dev/blobby" ) /** * @klotho::persist { * id = "test" * } */ - bucket, err := fileblobby.OpenBucket(myDir, nil)`, + bucket, err := blobby.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path))`, }, } for _, tt := range tests { @@ -128,6 +114,7 @@ bucket, err = fileblob.OpenBucket("myDir", nil)`, assert.Nil(result) return } + fmt.Println(result) assert.Equal(tt.want.varName, result.varName) }) } @@ -148,14 +135,14 @@ func Test_Transform(t *testing.T) { name: "simple file blob", source: `package fs import ( - "gocloud.dev/blob/fileblob" + "gocloud.dev/blob" ) /** * @klotho::persist { * id = "test" * } */ -bucket, err := fileblob.OpenBucket("myDir", nil) +bucket, err := blob.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path)) `, want: testResult{ resource: core.Persist{ @@ -165,8 +152,8 @@ bucket, err := fileblob.OpenBucket("myDir", nil) content: `package fs import ( - "gocloud.dev/blob" _ "gocloud.dev/blob/s3blob" + "gocloud.dev/blob" ) /** @@ -174,7 +161,8 @@ import ( * id = "test" * } */ -bucket, err := blob.OpenBucket(nil, "s3://" + os.Getenv("test_fs_bucket") + "?region=" + os.Getenv("AWS_REGION")) +var _ = fmt.Sprintf("file://%s", path) +bucket, err := blob.OpenBucket(context.Background(), "s3://" + os.Getenv("test_fs_bucket") + "?region=" + os.Getenv("AWS_REGION")) `, }, }, @@ -182,14 +170,14 @@ bucket, err := blob.OpenBucket(nil, "s3://" + os.Getenv("test_fs_bucket") + "?re name: "long var file blob", source: `package fs import ( - "gocloud.dev/blob/fileblob" + "gocloud.dev/blob" ) /** * @klotho::persist { * id = "test" * } */ -var bucket, err = fileblob.OpenBucket("myDir", nil) +var bucket, err = blob.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path)) `, want: testResult{ resource: core.Persist{ @@ -199,8 +187,8 @@ var bucket, err = fileblob.OpenBucket("myDir", nil) content: `package fs import ( - "gocloud.dev/blob" _ "gocloud.dev/blob/s3blob" + "gocloud.dev/blob" ) /** @@ -208,7 +196,8 @@ import ( * id = "test" * } */ -var bucket, err = blob.OpenBucket(nil, "s3://" + os.Getenv("test_fs_bucket") + "?region=" + os.Getenv("AWS_REGION")) +var _ = fmt.Sprintf("file://%s", path) +var bucket, err = blob.OpenBucket(context.Background(), "s3://" + os.Getenv("test_fs_bucket") + "?region=" + os.Getenv("AWS_REGION")) `, }, }, @@ -216,15 +205,16 @@ var bucket, err = blob.OpenBucket(nil, "s3://" + os.Getenv("test_fs_bucket") + " name: "var deckaration file blob", source: `package fs import ( - "gocloud.dev/blob/fileblob" + "gocloud.dev/blob" ) +var bucket *blob.Bucket +var err error /** * @klotho::persist { * id = "test" * } */ -var bucket, err -bucket, err = fileblob.OpenBucket("myDir", nil) +bucket, err = blob.OpenBucket(context.Background(), fmt.Sprintf("file://%s", path)) `, want: testResult{ resource: core.Persist{ @@ -234,34 +224,22 @@ bucket, err = fileblob.OpenBucket("myDir", nil) content: `package fs import ( - "gocloud.dev/blob" _ "gocloud.dev/blob/s3blob" + "gocloud.dev/blob" ) +var bucket *blob.Bucket +var err error /** * @klotho::persist { * id = "test" * } */ -var bucket, err -bucket, err = blob.OpenBucket(nil, "s3://" + os.Getenv("test_fs_bucket") + "?region=" + os.Getenv("AWS_REGION")) +var _ = fmt.Sprintf("file://%s", path) +bucket, err = blob.OpenBucket(context.Background(), "s3://" + os.Getenv("test_fs_bucket") + "?region=" + os.Getenv("AWS_REGION")) `, }, }, - { - name: "non string as path throws err", - source: `package fs - import ( - alias "gocloud.dev/blob/fileblob" - ) - /** - * @klotho::persist { - * id = "test" - * } - */ - bucket, err := alias.OpenBucket(myDir, nil)`, - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/pkg/lang/golang/plugin_secrets.go b/pkg/lang/golang/plugin_secrets.go new file mode 100644 index 000000000..a8476bacd --- /dev/null +++ b/pkg/lang/golang/plugin_secrets.go @@ -0,0 +1,157 @@ +package golang + +import ( + "fmt" + "strings" + + "github.com/klothoplatform/klotho/pkg/annotation" + "github.com/klothoplatform/klotho/pkg/config" + "github.com/klothoplatform/klotho/pkg/core" + "github.com/klothoplatform/klotho/pkg/logging" + "github.com/klothoplatform/klotho/pkg/multierr" + "github.com/klothoplatform/klotho/pkg/query" + sitter "github.com/smacker/go-tree-sitter" + "go.uber.org/zap" +) + +type PersistSecretsPlugin struct { + runtime Runtime + config *config.Application +} + +func (p PersistSecretsPlugin) Name() string { return "Persist" } + +func (p PersistSecretsPlugin) Transform(result *core.CompilationResult, deps *core.Dependencies) error { + + var errs multierr.Error + for _, res := range result.Resources() { + unit, ok := res.(*core.ExecutionUnit) + if !ok { + continue + } + for _, goSource := range unit.FilesOfLang(goLang) { + resources, err := p.handleFile(goSource, unit) + if err != nil { + errs.Append(core.WrapErrf(err, "failed to handle persist in unit %s", unit.Name)) + continue + } + + for _, r := range resources { + result.Add(r) + + deps.Add(core.ResourceKey{ + Name: unit.Name, + Kind: core.ExecutionUnitKind, + }, r.Key()) + } + } + } + + return errs.ErrOrNil() +} + +func (p *PersistSecretsPlugin) handleFile(f *core.SourceFile, unit *core.ExecutionUnit) ([]core.CloudResource, error) { + resources := []core.CloudResource{} + var errs multierr.Error + annots := f.Annotations() + for _, annot := range annots { + cap := annot.Capability + if cap.Name != annotation.PersistCapability { + continue + } + secretsResult := querySecret(f, annot) + if secretsResult != nil { + persistResource, err := p.transformSecret(f, annot, secretsResult, unit) + if err != nil { + errs.Append(err) + } + resources = append(resources, persistResource) + + } + } + return resources, errs.ErrOrNil() +} + +func (p *PersistSecretsPlugin) transformSecret(f *core.SourceFile, cap *core.Annotation, result *persistSecretResult, unit *core.ExecutionUnit) (core.CloudResource, error) { + + args, found := getArguements(result.expression) + if !found { + return nil, nil + } + // Generate the new node content before replacing the node. + // We are going to set a new variable to the original file path and split to get the query params + newNodeContent := `klothoRuntimePathSub := ` + args[1].Content + //Split the path to get anything after ? so we can get the query params + newNodeContent += "\nklothoRuntimePathSubChunks := strings.SplitN(klothoRuntimePathSub, \"?\", 2)\n" + newNodeContent += `var queryParams string + if len(klothoRuntimePathSubChunks) == 2 { + queryParams = "&" + klothoRuntimePathSubChunks[1] + } + ` + + args[1].Content = fmt.Sprintf(`"awssecretsmanager://%s?region=" + os.Getenv("AWS_REGION") + queryParams`, p.config.AppName+"_"+cap.Capability.ID) + + newArgContent := argumentListToString(args) + + newExpressionContent := strings.ReplaceAll(result.expression.Content(), result.args.Content(), newArgContent) + newNodeContent += newExpressionContent + + err := f.ReplaceNodeContent(result.expression, newNodeContent) + if err != nil { + return nil, err + } + + err = UpdateImportsInFile(f, p.runtime.GetSecretsImports(), []Import{}) + if err != nil { + return nil, err + } + + persist := &core.Persist{ + Kind: core.PersistSecretKind, + Name: cap.Capability.ID, + } + return persist, nil +} + +type persistSecretResult struct { + varName string + args *sitter.Node + expression *sitter.Node +} + +func querySecret(file *core.SourceFile, annotation *core.Annotation) *persistSecretResult { + log := zap.L().With(logging.FileField(file), logging.AnnotationField(annotation)) + + runtimeVarImport := GetNamedImportInFile(file, "gocloud.dev/runtimevar") + + nextMatch := doQuery(annotation.Node, openVariable) + + match, found := nextMatch() + if !found { + return nil + } + varName, args, id := match["varName"], match["args"], match["id"] + + if id != nil { + if runtimeVarImport.Alias != "" { + if !query.NodeContentEquals(id, runtimeVarImport.Alias) { + return nil + } + } else { + if !query.NodeContentEquals(id, "runtimevar") { + return nil + } + } + } + + if _, found := nextMatch(); found { + log.Warn("too many assignments for fs_secrets") + return nil + } + + return &persistSecretResult{ + varName: varName.Content(), + args: args, + expression: match["expression"], + } +} diff --git a/pkg/lang/golang/plugin_secrets_test.go b/pkg/lang/golang/plugin_secrets_test.go new file mode 100644 index 000000000..dad94af36 --- /dev/null +++ b/pkg/lang/golang/plugin_secrets_test.go @@ -0,0 +1,297 @@ +package golang + +import ( + "fmt" + "strings" + "testing" + + "github.com/klothoplatform/klotho/pkg/config" + "github.com/klothoplatform/klotho/pkg/core" + "github.com/stretchr/testify/assert" +) + +func Test_querySecrets(t *testing.T) { + tests := []struct { + name string + source string + want *persistResult + wantErr bool + }{ + { + name: "simple runtime var", + source: ` +import ( + "gocloud.dev/runtimevar" +) +/** +* @klotho::persist { +* id = "test" +* } +*/ +v, err := runtimevar.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path))`, + want: &persistResult{ + varName: "v", + }, + }, + { + name: "simple var runtime var", + source: ` +import ( + "gocloud.dev/runtimevar" +) +/** +* @klotho::persist { +* id = "test" +* } +*/ +var v, err = runtimevar.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path))`, + want: &persistResult{ + varName: "v", + }, + }, + { + name: "simple var declaration", + source: ` +import ( + "gocloud.dev/runtimevar" +) +var v *runtimevar.Variable +var err error +/** +* @klotho::persist { +* id = "test" +* } +*/ +v, err = runtimevar.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path))`, + want: &persistResult{ + varName: "v", + }, + }, + { + name: "aliased file blob", + source: ` +import ( + alias "gocloud.dev/runtimevar" +) +/** +* @klotho::persist { +* id = "test" +* } +*/ +v, err := alias.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path))`, + want: &persistResult{ + varName: "v", + }, + }, + { + name: "wrong import no match", + source: ` +import ( + "gocloud.dev/runtimevarrrrr" +) +/** +* @klotho::persist { +* id = "test" +* } +*/ +v, err := runtimevarrrrr.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path))`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + + f, err := core.NewSourceFile("test.go", strings.NewReader(tt.source), Language) + if !assert.NoError(err) { + return + } + annot, ok := f.Annotations()[core.AnnotationKey{Capability: "persist", ID: "test"}] + + if !assert.True(ok) { + return + } + result := querySecret(f, annot) + if tt.want == nil { + assert.Nil(result) + return + } + fmt.Println(result) + assert.Equal(tt.want.varName, result.varName) + }) + } +} + +func Test_TransformSecrets(t *testing.T) { + type testResult struct { + resource core.Persist + content string + } + tests := []struct { + name string + source string + want testResult + wantErr bool + }{ + { + name: "simple open var", + source: `package fs +import ( + "gocloud.dev/runtimevar" +) +/** +* @klotho::persist { +* id = "test" +* secret = true +* } +*/ +v, err := runtimevar.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path)) +`, + want: testResult{ + resource: core.Persist{ + Kind: core.PersistSecretKind, + Name: "test", + }, + content: `package fs + +import ( + _ "gocloud.dev/runtimevar/awssecretsmanager" + "gocloud.dev/runtimevar" +) + +/** +* @klotho::persist { +* id = "test" +* secret = true +* } +*/ +klothoRuntimePathSub := fmt.Sprintf("file://%s?decoder=string", path) +klothoRuntimePathSubChunks := strings.SplitN(klothoRuntimePathSub, "?", 2) +var queryParams string + if len(klothoRuntimePathSubChunks) == 2 { + queryParams = "&" + klothoRuntimePathSubChunks[1] + } + v, err := runtimevar.OpenVariable(context.TODO(), "awssecretsmanager://app_test?region=" + os.Getenv("AWS_REGION") + queryParams) +`, + }, + }, + { + name: "long var open var", + source: `package fs +import ( + "gocloud.dev/runtimevar" +) +/** +* @klotho::persist { +* id = "test" +* secret = true +* } +*/ +var v, err = runtimevar.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path)) +`, + want: testResult{ + resource: core.Persist{ + Kind: core.PersistSecretKind, + Name: "test", + }, + content: `package fs + +import ( + _ "gocloud.dev/runtimevar/awssecretsmanager" + "gocloud.dev/runtimevar" +) + +/** +* @klotho::persist { +* id = "test" +* secret = true +* } +*/ +klothoRuntimePathSub := fmt.Sprintf("file://%s?decoder=string", path) +klothoRuntimePathSubChunks := strings.SplitN(klothoRuntimePathSub, "?", 2) +var queryParams string + if len(klothoRuntimePathSubChunks) == 2 { + queryParams = "&" + klothoRuntimePathSubChunks[1] + } + var v, err = runtimevar.OpenVariable(context.TODO(), "awssecretsmanager://app_test?region=" + os.Getenv("AWS_REGION") + queryParams) +`, + }, + }, + { + name: "var declaration open var", + source: `package fs +import ( + "gocloud.dev/runtimevar" +) +var v *runtimevar.Variable +var err error +/** +* @klotho::persist { +* id = "test" +* secret = true +* } +*/ +v, err = runtimevar.OpenVariable(context.TODO(), fmt.Sprintf("file://%s?decoder=string", path)) +`, + want: testResult{ + resource: core.Persist{ + Kind: core.PersistSecretKind, + Name: "test", + }, + content: `package fs + +import ( + _ "gocloud.dev/runtimevar/awssecretsmanager" + "gocloud.dev/runtimevar" +) + +var v *runtimevar.Variable +var err error +/** +* @klotho::persist { +* id = "test" +* secret = true +* } +*/ +klothoRuntimePathSub := fmt.Sprintf("file://%s?decoder=string", path) +klothoRuntimePathSubChunks := strings.SplitN(klothoRuntimePathSub, "?", 2) +var queryParams string + if len(klothoRuntimePathSubChunks) == 2 { + queryParams = "&" + klothoRuntimePathSubChunks[1] + } + v, err = runtimevar.OpenVariable(context.TODO(), "awssecretsmanager://app_test?region=" + os.Getenv("AWS_REGION") + queryParams) +`, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + + cfg := config.Application{AppName: "app"} + p := PersistSecretsPlugin{runtime: NoopRuntime{}, config: &cfg} + unit := core.ExecutionUnit{} + + f, err := core.NewSourceFile("test.go", strings.NewReader(tt.source), Language) + if !assert.NoError(err) { + return + } + annot, ok := f.Annotations()[core.AnnotationKey{Capability: "persist", ID: "test"}] + + if !assert.True(ok) { + return + } + queryResult := querySecret(f, annot) + fmt.Println(queryResult) + result, err := p.transformSecret(f, annot, queryResult, &unit) + if tt.wantErr { + assert.Error(err) + return + } else if !assert.NoError(err) { + return + } + + assert.Equal(tt.want.resource.Key(), result.Key()) + assert.Equal(tt.want.content, string(f.Program())) + }) + } +} diff --git a/pkg/lang/golang/plugins.go b/pkg/lang/golang/plugins.go index d9ae69800..647aaa4ca 100644 --- a/pkg/lang/golang/plugins.go +++ b/pkg/lang/golang/plugins.go @@ -18,6 +18,7 @@ func NewGoPlugins(cfg *config.Application, runtime Runtime) *GoPlugins { &Expose{Config: cfg}, &AddExecRuntimeFiles{cfg: cfg, runtime: runtime}, &PersistFsPlugin{runtime: runtime}, + &PersistSecretsPlugin{runtime: runtime, config: cfg}, }, } } diff --git a/pkg/lang/golang/queries.go b/pkg/lang/golang/queries.go index 15befa004..61f072405 100644 --- a/pkg/lang/golang/queries.go +++ b/pkg/lang/golang/queries.go @@ -7,8 +7,8 @@ import ( //go:embed queries/imports.scm var findImports string -//go:embed queries/find_args.scm -var findArgs string +//go:embed queries/find_function_call.scm +var findFunctionCall string //go:embed queries/expose/chirouter_assignment.scm var findRouterAssignment string @@ -30,3 +30,6 @@ var packageQuery string //go:embed queries/gocloud/file_bucket.scm var fileBucket string + +//go:embed queries/gocloud/open_variable.scm +var openVariable string diff --git a/pkg/lang/golang/queries/find_args.scm b/pkg/lang/golang/queries/find_args.scm deleted file mode 100644 index 536695131..000000000 --- a/pkg/lang/golang/queries/find_args.scm +++ /dev/null @@ -1,4 +0,0 @@ -; Finds arguments in an arguments_list -(argument_list - (_) @arg -) diff --git a/pkg/lang/golang/queries/find_function_call.scm b/pkg/lang/golang/queries/find_function_call.scm new file mode 100644 index 000000000..eef57aee2 --- /dev/null +++ b/pkg/lang/golang/queries/find_function_call.scm @@ -0,0 +1,7 @@ +; Finds arguments in an arguments_list + (call_expression + function: (_) @function + arguments: (argument_list + (_)@arg + ) +)@call diff --git a/pkg/lang/golang/queries/gocloud/open_variable.scm b/pkg/lang/golang/queries/gocloud/open_variable.scm new file mode 100644 index 000000000..0658ef072 --- /dev/null +++ b/pkg/lang/golang/queries/gocloud/open_variable.scm @@ -0,0 +1,51 @@ +[ + (short_var_declaration + left: (expression_list + (identifier) @varName + (identifier) + ) @variables + right: (expression_list + (call_expression + function: (selector_expression + operand: (identifier) @id + field: (field_identifier) @method + ) + arguments: (argument_list) @args + (#match? @method "OpenVariable") + )@call + ) +)@expression ;; v, err := runtimevar.OpenVariable(context.Background(), "my_secret.key?decoder=string") +(assignment_statement + left: (expression_list + (identifier) @varName + (identifier) + ) @variables + right: (expression_list + (call_expression + function: (selector_expression + operand: (identifier) @id + field: (field_identifier) @method + ) + arguments: (argument_list) @args + (#match? @method "OpenVariable") + )@call + ) +)@expression ;; v, err = runtimevar.OpenVariable(context.Background(), "my_secret.key?decoder=string") +(var_declaration + (var_spec + name: (identifier) @varName + value: (expression_list + (call_expression + function: (selector_expression + operand: (identifier) @id + field: (field_identifier) @method + ) + arguments: (argument_list) @args + (#match? @method "OpenVariable") + )@call + ) + ) +)@expression ;; var v, err = runtimevar.OpenVariable(context.Background(), "my_secret.key?decoder=string") + +] + diff --git a/pkg/lang/golang/runtime.go b/pkg/lang/golang/runtime.go index 5d1eddb33..dd329e749 100644 --- a/pkg/lang/golang/runtime.go +++ b/pkg/lang/golang/runtime.go @@ -14,6 +14,7 @@ type ( Runtime interface { AddExecRuntimeFiles(unit *core.ExecutionUnit, result *core.CompilationResult, deps *core.Dependencies) error GetFsImports() []Import + GetSecretsImports() []Import } ) diff --git a/pkg/lang/golang/runtime_test.go b/pkg/lang/golang/runtime_test.go index 8029f21ee..ef3fcdfb6 100644 --- a/pkg/lang/golang/runtime_test.go +++ b/pkg/lang/golang/runtime_test.go @@ -15,3 +15,9 @@ func (n NoopRuntime) GetFsImports() []Import { {Alias: "_", Package: "gocloud.dev/blob/s3blob"}, } } +func (n NoopRuntime) GetSecretsImports() []Import { + return []Import{ + {Package: "gocloud.dev/runtimevar"}, + {Alias: "_", Package: "gocloud.dev/runtimevar/awssecretsmanager"}, + } +} diff --git a/pkg/provider/aws/infra_template.go b/pkg/provider/aws/infra_template.go index 9a18cd92c..cdfc01397 100644 --- a/pkg/provider/aws/infra_template.go +++ b/pkg/provider/aws/infra_template.go @@ -141,9 +141,7 @@ func (a *AWS) Transform(result *core.CompilationResult, deps *core.Dependencies) }) data.UseVPC = true } - - case *core.Fs: - if res.GenerateNewFs { + if res.Kind == core.PersistFileKind && res.GenerateNew { data.Buckets = append(data.Buckets, provider.FS{ Name: res.Name, }) diff --git a/pkg/provider/aws/infra_template_test.go b/pkg/provider/aws/infra_template_test.go index 7fc3e9a80..137c19bfe 100644 --- a/pkg/provider/aws/infra_template_test.go +++ b/pkg/provider/aws/infra_template_test.go @@ -82,11 +82,10 @@ func TestInfraTemplateModification(t *testing.T) { { name: "bucket test", results: []core.CloudResource{ - &core.Fs{ - Persist: core.Persist{ - Name: "bucket", - }, - GenerateNewFs: true, + &core.Persist{ + Name: "bucket", + Kind: core.PersistFileKind, + GenerateNew: true, }, }, cfg: config.Application{ @@ -103,11 +102,10 @@ func TestInfraTemplateModification(t *testing.T) { { name: "not new bucket test", results: []core.CloudResource{ - &core.Fs{ - Persist: core.Persist{ - Name: "bucket", - }, - GenerateNewFs: false, + &core.Persist{ + Name: "bucket", + Kind: core.PersistFileKind, + GenerateNew: false, }, }, cfg: config.Application{