diff --git a/env.go b/env.go new file mode 100644 index 0000000..606fa77 --- /dev/null +++ b/env.go @@ -0,0 +1,83 @@ +// Copyright © 2023 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + "os" + "strings" + + "github.com/bank-vaults/secret-init/provider" + "github.com/bank-vaults/secret-init/provider/file" +) + +func GetEnvironMap() map[string]string { + environ := make(map[string]string, len(os.Environ())) + for _, env := range os.Environ() { + split := strings.SplitN(env, "=", 2) + name := split[0] + value := split[1] + environ[name] = value + } + + return environ +} + +func ExtractPathsFromEnvs(envs map[string]string) []string { + var secretPaths []string + + for _, path := range envs { + if p, path := getProviderPath(path); p != nil { + secretPaths = append(secretPaths, path) + } + } + + return secretPaths +} + +func CreateSecretEnvsFrom(envs map[string]string, secrets []provider.Secret) ([]string, error) { + // Reverse the map so we can match + // the environment variable key to the secret + // by using the secret path + reversedEnvs := make(map[string]string) + for envKey, path := range envs { + if p, path := getProviderPath(path); p != nil { + reversedEnvs[path] = envKey + } + } + + var secretsEnv []string + for _, secret := range secrets { + path := secret.Path + value := secret.Value + key, ok := reversedEnvs[path] + if !ok { + return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", path) + } + secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", key, value)) + } + + return secretsEnv, nil +} + +// Returns the detected provider name and path with removed prefix +func getProviderPath(path string) (*string, string) { + if strings.HasPrefix(path, "file:") { + var fileProviderName = file.ProviderName + return &fileProviderName, strings.TrimPrefix(path, "file:") + } + + return nil, path +} diff --git a/go.mod b/go.mod index f6224cb..f302877 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,14 @@ require ( github.com/spf13/cast v1.5.1 ) +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + require ( github.com/samber/lo v1.38.1 // indirect + github.com/stretchr/testify v1.8.4 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 // indirect ) diff --git a/go.sum b/go.sum index 38c3e2a..8e6c1e7 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -6,6 +8,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= @@ -16,5 +20,11 @@ github.com/samber/slog-syslog v1.0.0 h1:4tf8sNv9+qTQ6Fj8+N6U1ZEtUbqbAIzd+q26/Neg github.com/samber/slog-syslog v1.0.0/go.mod h1:jjupk+yHPVSuXuGhKleoClYc/HEaC+Ro5X4YYeBrt6g= github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA= github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc= golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 216ebab..5c42d0d 100644 --- a/main.go +++ b/main.go @@ -32,8 +32,24 @@ import ( "github.com/spf13/cast" "github.com/bank-vaults/secret-init/provider" + "github.com/bank-vaults/secret-init/provider/file" ) +func NewProvider(providerName string) (provider.Provider, error) { + switch providerName { + case file.ProviderName: + provider, err := file.NewProvider(os.DirFS("/")) + if err != nil { + return nil, err + } + + return provider, nil + + default: + return nil, errors.New("invalid provider specified") + } +} + func main() { var logger *slog.Logger { @@ -94,8 +110,12 @@ func main() { slog.SetDefault(logger) } - // TODO: enable providers - var provider provider.Provider + provider, err := NewProvider(os.Getenv("PROVIDER")) + if err != nil { + logger.Error(fmt.Errorf("failed to create provider: %w", err).Error()) + + os.Exit(1) + } if len(os.Args) == 1 { logger.Error("no command is given, vault-env can't determine the entrypoint (command), please specify it explicitly or let the webhook query it (see documentation)") @@ -115,10 +135,19 @@ func main() { os.Exit(1) } + environ := GetEnvironMap() + paths := ExtractPathsFromEnvs(environ) + ctx := context.Background() - envs, err := provider.LoadSecrets(ctx, os.Environ()) + secrets, err := provider.LoadSecrets(ctx, paths) + if err != nil { + logger.Error(fmt.Errorf("failed to load secrets from provider: %w", err).Error()) + + os.Exit(1) + } + secretsEnv, err := CreateSecretEnvsFrom(environ, secrets) if err != nil { - logger.Error("could not retrieve secrets from the provider.", err) + logger.Error(fmt.Errorf("failed to create environment variables from loaded secrets: %w", err).Error()) os.Exit(1) } @@ -135,7 +164,7 @@ func main() { if daemonMode { logger.Info("in daemon mode...") cmd := exec.Command(binary, entrypointCmd[1:]...) - cmd.Env = append(os.Environ(), envs...) + cmd.Env = append(os.Environ(), secretsEnv...) cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -184,7 +213,7 @@ func main() { os.Exit(cmd.ProcessState.ExitCode()) } - err = syscall.Exec(binary, entrypointCmd, envs) + err = syscall.Exec(binary, entrypointCmd, secretsEnv) if err != nil { logger.Error(fmt.Errorf("failed to exec process: %w", err).Error(), slog.String("entrypoint", fmt.Sprint(entrypointCmd))) diff --git a/provider/file/file.go b/provider/file/file.go new file mode 100644 index 0000000..ba6a080 --- /dev/null +++ b/provider/file/file.go @@ -0,0 +1,66 @@ +// Copyright © 2023 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "context" + "fmt" + "io/fs" + "strings" + + "github.com/bank-vaults/secret-init/provider" +) + +const ProviderName = "file" + +type Provider struct { + fs fs.FS +} + +func NewProvider(fs fs.FS) (provider.Provider, error) { + if fs == nil { + return nil, fmt.Errorf("file system is nil") + } + + return &Provider{fs: fs}, nil +} + +func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Secret, error) { + var secrets []provider.Secret + + for _, path := range paths { + secret, err := p.getSecretFromFile(path) + if err != nil { + return nil, fmt.Errorf("failed to get secret from file: %w", err) + } + + secrets = append(secrets, provider.Secret{ + Path: path, + Value: secret, + }) + } + + return secrets, nil +} + +func (p *Provider) getSecretFromFile(filepath string) (string, error) { + filepath = strings.TrimLeft(filepath, "/") + content, err := fs.ReadFile(p.fs, filepath) + if err != nil { + return "", fmt.Errorf("failed to read file: %w", err) + } + + return string(content), nil +} diff --git a/provider/file/file_test.go b/provider/file/file_test.go new file mode 100644 index 0000000..50ddddb --- /dev/null +++ b/provider/file/file_test.go @@ -0,0 +1,125 @@ +// Copyright © 2023 Bank-Vaults Maintainers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "context" + "io/fs" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/assert" + + "github.com/bank-vaults/secret-init/provider" +) + +func TestNewProvider(t *testing.T) { + tests := []struct { + name string + fs fs.FS + wantErr bool + wantType bool + }{ + { + name: "Valid file system", + fs: fstest.MapFS{ + "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, + "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, + }, + wantErr: false, + wantType: true, + }, + { + name: "Nil file system", + fs: nil, + wantErr: true, + wantType: false, + }, + } + + for _, tt := range tests { + ttp := tt + t.Run(ttp.name, func(t *testing.T) { + prov, err := NewProvider(ttp.fs) + if (err != nil) != ttp.wantErr { + t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr) + return + } + // Use type assertion to check if the provider is of the correct type + _, ok := prov.(*Provider) + if ok != ttp.wantType { + t.Fatalf("NewProvider() = %v, wantType %v", ok, ttp.wantType) + } + }) + } +} + +func TestLoadSecrets(t *testing.T) { + tests := []struct { + name string + fs fs.FS + paths []string + wantErr bool + wantData []provider.Secret + }{ + { + name: "Load secrets successfully", + fs: fstest.MapFS{ + "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, + "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, + }, + paths: []string{ + "test/secrets/sqlpass.txt", + "test/secrets/awsaccess.txt", + "test/secrets/awsid.txt", + }, + wantErr: false, + wantData: []provider.Secret{ + {Path: "test/secrets/sqlpass.txt", Value: "3xtr3ms3cr3t"}, + {Path: "test/secrets/awsaccess.txt", Value: "s3cr3t"}, + {Path: "test/secrets/awsid.txt", Value: "secretId"}, + }, + }, + { + name: "Fail to load secrets due to invalid path", + fs: fstest.MapFS{ + "test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")}, + "test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")}, + "test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")}, + }, + paths: []string{ + "test/secrets/mistake/sqlpass.txt", + "test/secrets/mistake/awsaccess.txt", + "test/secrets/mistake/awsid.txt", + }, + wantErr: true, + wantData: nil, + }, + } + + for _, tt := range tests { + ttp := tt + t.Run(ttp.name, func(t *testing.T) { + provider, err := NewProvider(ttp.fs) + if assert.NoError(t, err, "Unexpected error") { + secrets, err := provider.LoadSecrets(context.Background(), ttp.paths) + assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status") + assert.ElementsMatch(t, ttp.wantData, secrets, "Unexpected secrets loaded") + } + }) + } +} diff --git a/provider/provider.go b/provider/provider.go index c690a0c..14e07fc 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -18,5 +18,11 @@ import "context" // Provider is an interface for securely loading secrets based on environment variables. type Provider interface { - LoadSecrets(ctx context.Context, paths []string) ([]string, error) + LoadSecrets(ctx context.Context, paths []string) ([]Secret, error) +} + +// Secret holds Provider-specific secret data. +type Secret struct { + Path string + Value string }