diff --git a/env_store.go b/env_store.go index 780bf95..958b543 100644 --- a/env_store.go +++ b/env_store.go @@ -32,13 +32,37 @@ import ( "github.com/bank-vaults/secret-init/pkg/provider/vault" ) -var supportedProviders = []string{ - file.ProviderName, - vault.ProviderName, - bao.ProviderName, - aws.ProviderName, - gcp.ProviderName, - azure.ProviderName, +var factories = []provider.Factory{ + { + ProviderType: file.ProviderType, + Validator: file.Valid, + Create: file.NewProvider, + }, + { + ProviderType: vault.ProviderType, + Validator: vault.Valid, + Create: vault.NewProvider, + }, + { + ProviderType: bao.ProviderType, + Validator: bao.Valid, + Create: bao.NewProvider, + }, + { + ProviderType: aws.ProviderType, + Validator: aws.Valid, + Create: aws.NewProvider, + }, + { + ProviderType: gcp.ProviderType, + Validator: gcp.Valid, + Create: gcp.NewProvider, + }, + { + ProviderType: azure.ProviderType, + Validator: azure.Valid, + Create: azure.NewProvider, + }, } // EnvStore is a helper for managing interactions between environment variables and providers, @@ -66,28 +90,11 @@ func NewEnvStore(appConfig *common.Config) *EnvStore { // GetSecretReferences returns a map of secret key=value pairs for each provider func (s *EnvStore) GetSecretReferences() map[string][]string { secretReferences := make(map[string][]string) - for envKey, envPath := range s.data { - providerName, envSecretReference := getProviderPath(envPath) - envSecretReference = envKey + "=" + envSecretReference - switch providerName { - case file.ProviderName: - secretReferences[file.ProviderName] = append(secretReferences[file.ProviderName], envSecretReference) - - case vault.ProviderName: - secretReferences[vault.ProviderName] = append(secretReferences[vault.ProviderName], envSecretReference) - - case bao.ProviderName: - secretReferences[bao.ProviderName] = append(secretReferences[bao.ProviderName], envSecretReference) - - case aws.ProviderName: - secretReferences[aws.ProviderName] = append(secretReferences[aws.ProviderName], envSecretReference) - - case gcp.ProviderName: - secretReferences[gcp.ProviderName] = append(secretReferences[gcp.ProviderName], envSecretReference) - - case azure.ProviderName: - secretReferences[azure.ProviderName] = append(secretReferences[azure.ProviderName], envSecretReference) + for _, factory := range factories { + if factory.Validator(envPath) { + secretReferences[factory.ProviderType] = append(secretReferences[factory.ProviderType], fmt.Sprintf("%s=%s", envKey, envPath)) + } } } @@ -98,56 +105,52 @@ func (s *EnvStore) GetSecretReferences() map[string][]string { // It then asynchronously loads secrets using each provider and it's corresponding paths. // The secrets from each provider are then placed into a single slice. func (s *EnvStore) LoadProviderSecrets(ctx context.Context, providerPaths map[string][]string) ([]provider.Secret, error) { - // At most, we will have one error per provider - errCh := make(chan error, len(supportedProviders)) var providerSecrets []provider.Secret - // Workaround for openBao // Remove once openBao uses BAO_ADDR in their client, instead of VAULT_ADDR - vaultPaths, ok := providerPaths[vault.ProviderName] - if ok { - var err error - providerSecrets, err = s.workaroundForBao(ctx, vaultPaths) + if _, ok := providerPaths[vault.ProviderType]; ok { + vaultSecrets, err := s.workaroundForBao(ctx, providerPaths[vault.ProviderType]) if err != nil { - return nil, fmt.Errorf("failed to workaround for bao: %w", err) + return nil, err } - // Remove the vault paths since they have been processed - delete(providerPaths, vault.ProviderName) + providerSecrets = append(providerSecrets, vaultSecrets...) + delete(providerPaths, vault.ProviderType) } + // At most, we will have one error per provider + errCh := make(chan error, len(factories)) var wg sync.WaitGroup var mu sync.Mutex - for providerName, paths := range providerPaths { wg.Add(1) - go func(providerName string, paths []string, errCh chan<- error) { defer wg.Done() - provider, err := newProvider(ctx, providerName, s.appConfig) - if err != nil { - errCh <- fmt.Errorf("failed to create provider %s: %w", providerName, err) - return + for _, factory := range factories { + if factory.ProviderType == providerName { + provider, err := factory.Create(ctx, s.appConfig) + if err != nil { + errCh <- fmt.Errorf("failed to create provider %s: %w", providerName, err) + return + } + + secrets, err := provider.LoadSecrets(ctx, paths) + if err != nil { + errCh <- fmt.Errorf("failed to load secrets for provider %s: %w", providerName, err) + return + } + + mu.Lock() + providerSecrets = append(providerSecrets, secrets...) + mu.Unlock() + } } - - secrets, err := provider.LoadSecrets(ctx, paths) - if err != nil { - errCh <- fmt.Errorf("failed to load secrets for provider %s: %w", providerName, err) - return - } - - mu.Lock() - providerSecrets = append(providerSecrets, secrets...) - mu.Unlock() }(providerName, paths, errCh) } - - // Wait for all providers to finish wg.Wait() close(errCh) - // Check for errors var errs error for e := range errCh { if e != nil { @@ -163,19 +166,25 @@ func (s *EnvStore) LoadProviderSecrets(ctx context.Context, providerPaths map[st // Workaround for openBao, essentially loading secretes from Vault first. func (s *EnvStore) workaroundForBao(ctx context.Context, vaultPaths []string) ([]provider.Secret, error) { - var secrets []provider.Secret + var providerSecrets []provider.Secret + for _, factory := range factories { + if factory.ProviderType == vault.ProviderType { + provider, err := factory.Create(ctx, s.appConfig) + if err != nil { + return nil, fmt.Errorf("failed to create provider %s: %w", factory.ProviderType, err) + } - provider, err := newProvider(ctx, vault.ProviderName, s.appConfig) - if err != nil { - return nil, fmt.Errorf("failed to create provider %s: %w", vault.ProviderName, err) - } + secrets, err := provider.LoadSecrets(ctx, vaultPaths) + if err != nil { + return nil, fmt.Errorf("failed to load secrets for provider %s: %w", factory.ProviderType, err) + } - secrets, err = provider.LoadSecrets(ctx, vaultPaths) - if err != nil { - return nil, fmt.Errorf("failed to load secrets for provider %s: %w", vault.ProviderName, err) + providerSecrets = append(providerSecrets, secrets...) + break + } } - return secrets, nil + return providerSecrets, nil } // ConvertProviderSecrets converts the loaded secrets to environment variables @@ -188,111 +197,3 @@ func (s *EnvStore) ConvertProviderSecrets(providerSecrets []provider.Secret) []s return secretsEnv } - -// Returns the detected provider name and path with removed prefix -func getProviderPath(path string) (string, string) { - if strings.HasPrefix(path, "file:") { - return file.ProviderName, path - } - - // If the path contains some string formatted as "vault:{STR}#{STR}" - // it is most probably a vault path - if vault.ProviderEnvRegex.MatchString(path) { - return vault.ProviderName, path - } - - // If the path contains some string formatted as "bao:{STR}#{STR}" - // it is most probably a vault path - if bao.ProviderEnvRegex.MatchString(path) { - return bao.ProviderName, path - } - - // Example AWS prefixes: - // arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret - // arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter - if strings.HasPrefix(path, "arn:aws:secretsmanager:") || strings.HasPrefix(path, "arn:aws:ssm:") { - return aws.ProviderName, path - } - - // Example GCP prefixes: - // gcp:secretmanager:projects/{PROJECT_ID}/secrets/{SECRET_NAME} - // gcp:secretmanager:projects/{PROJECT_ID}/secrets/{SECRET_NAME}/versions/{VERSION|latest} - if strings.HasPrefix(path, "gcp:secretmanager:") { - return gcp.ProviderName, path - } - - // Example Azure Key Vault secret examples: - // azure:keyvault:{SECRET_NAME} - // azure:keyvault:{SECRET_NAME}/{VERSION} - if strings.HasPrefix(path, "azure:keyvault:") { - return azure.ProviderName, path - } - - return "", path -} - -func newProvider(ctx context.Context, providerName string, appConfig *common.Config) (provider.Provider, error) { - switch providerName { - case file.ProviderName: - config := file.LoadConfig() - provider, err := file.NewProvider(config) - if err != nil { - return nil, fmt.Errorf("failed to create file provider: %w", err) - } - return provider, nil - - case vault.ProviderName: - config, err := vault.LoadConfig() - if err != nil { - return nil, fmt.Errorf("failed to create vault config: %w", err) - } - - provider, err := vault.NewProvider(config, appConfig) - if err != nil { - return nil, fmt.Errorf("failed to create vault provider: %w", err) - } - return provider, nil - - case bao.ProviderName: - config, err := bao.LoadConfig() - if err != nil { - return nil, fmt.Errorf("failed to create bao config: %w", err) - } - - provider, err := bao.NewProvider(config, appConfig) - if err != nil { - return nil, fmt.Errorf("failed to create bao provider: %w", err) - } - return provider, nil - - case aws.ProviderName: - config, err := aws.LoadConfig() - if err != nil { - return nil, fmt.Errorf("failed to create aws config: %w", err) - } - - provider := aws.NewProvider(config) - return provider, nil - - case gcp.ProviderName: - provider, err := gcp.NewProvider(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create gcp provider: %w", err) - } - return provider, nil - - case azure.ProviderName: - config, err := azure.LoadConfig() - if err != nil { - return nil, fmt.Errorf("failed to create azure config: %w", err) - } - provider, err := azure.NewProvider(config) - if err != nil { - return nil, fmt.Errorf("failed to create azure provider: %w", err) - } - return provider, nil - - default: - return nil, fmt.Errorf("provider %s is not supported", providerName) - } -} diff --git a/env_store_test.go b/env_store_test.go index a959358..4594375 100644 --- a/env_store_test.go +++ b/env_store_test.go @@ -102,6 +102,32 @@ func TestEnvStore_GetSecretReferences(t *testing.T) { }, }, }, + { + name: "gcp provider", + envs: map[string]string{ + "GCP_SECRET1": "gcp:secretmanager:projects/my-project/secrets/my-secret/versions/1", + "GCP_SECRET2": "gcp:secretmanager:projects/my-project/secrets/my-secret/versions/latest", + }, + wantPaths: map[string][]string{ + "gcp": { + "GCP_SECRET1=gcp:secretmanager:projects/my-project/secrets/my-secret/versions/1", + "GCP_SECRET2=gcp:secretmanager:projects/my-project/secrets/my-secret/versions/latest", + }, + }, + }, + { + name: "azure provider", + envs: map[string]string{ + "AZURE_SECRET1": "azure:keyvault:my-keyvault/my-secret", + "AZURE_SECRET2": "azure:keyvault:my-keyvault/my-secret/latest", + }, + wantPaths: map[string][]string{ + "azure": { + "AZURE_SECRET1=azure:keyvault:my-keyvault/my-secret", + "AZURE_SECRET2=azure:keyvault:my-keyvault/my-secret/latest", + }, + }, + }, { name: "multi provider", envs: map[string]string{ @@ -112,6 +138,10 @@ func TestEnvStore_GetSecretReferences(t *testing.T) { "RABBITMQ_PASSWORD": "bao:secret/data/test/rabbitmq#RABBITMQ_PASSWORD", "AWS_SECRET1": "arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret", "AWS_SECRET2": "arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter", + "GCP_SECRET1": "gcp:secretmanager:projects/my-project/secrets/my-secret/versions/1", + "GCP_SECRET2": "gcp:secretmanager:projects/my-project/secrets/my-secret/versions/latest", + "AZURE_SECRET1": "azure:keyvault:my-keyvault/my-secret", + "AZURE_SECRET2": "azure:keyvault:my-keyvault/my-secret/latest", }, wantPaths: map[string][]string{ "file": { @@ -129,6 +159,14 @@ func TestEnvStore_GetSecretReferences(t *testing.T) { "AWS_SECRET1=arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret", "AWS_SECRET2=arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter", }, + "gcp": { + "GCP_SECRET1=gcp:secretmanager:projects/my-project/secrets/my-secret/versions/1", + "GCP_SECRET2=gcp:secretmanager:projects/my-project/secrets/my-secret/versions/latest", + }, + "azure": { + "AZURE_SECRET1=azure:keyvault:my-keyvault/my-secret", + "AZURE_SECRET2=azure:keyvault:my-keyvault/my-secret/latest", + }, }, }, } @@ -163,7 +201,6 @@ func TestEnvStore_LoadProviderSecrets(t *testing.T) { name string providerPaths map[string][]string wantProviderSecrets []provider.Secret - addvault bool err error }{ { @@ -179,7 +216,6 @@ func TestEnvStore_LoadProviderSecrets(t *testing.T) { Value: "secretId", }, }, - addvault: false, }, { name: "Fail to create provider", @@ -188,15 +224,14 @@ func TestEnvStore_LoadProviderSecrets(t *testing.T) { "AWS_SECRET_ACCESS_KEY_ID=file:" + secretFile, }, }, - addvault: false, - err: fmt.Errorf("failed to create provider invalid: provider invalid is not supported"), + err: fmt.Errorf("failed to create provider invalid: provider invalid is not supported"), }, } for _, tt := range tests { ttp := tt t.Run(ttp.name, func(t *testing.T) { - createEnvsForProvider(ttp.addvault, secretFile) + os.Setenv("AWS_SECRET_ACCESS_KEY_ID", "file:"+secretFile) providerSecrets, err := NewEnvStore(&common.Config{}).LoadProviderSecrets(context.Background(), ttp.providerPaths) if err != nil { @@ -217,7 +252,6 @@ func TestEnvStore_ConvertProviderSecrets(t *testing.T) { name string providerSecrets []provider.Secret wantSecretsEnv []string - addvault bool err error }{ { @@ -231,14 +265,13 @@ func TestEnvStore_ConvertProviderSecrets(t *testing.T) { wantSecretsEnv: []string{ "AWS_SECRET_ACCESS_KEY_ID=secretId", }, - addvault: false, }, } for _, tt := range tests { ttp := tt t.Run(ttp.name, func(t *testing.T) { - createEnvsForProvider(ttp.addvault, secretFile) + os.Setenv("AWS_SECRET_ACCESS_KEY_ID", "file:"+secretFile) secretsEnv := NewEnvStore(&common.Config{}).ConvertProviderSecrets(ttp.providerSecrets) if ttp.wantSecretsEnv != nil { @@ -248,14 +281,6 @@ func TestEnvStore_ConvertProviderSecrets(t *testing.T) { } } -func createEnvsForProvider(addVault bool, secretFile string) { - os.Setenv("AWS_SECRET_ACCESS_KEY_ID", "file:"+secretFile) - if addVault { - os.Setenv("MYSQL_PASSWORD", "vault:secret/data/test/mysql#MYSQL_PASSWORD") - os.Setenv("AWS_SECRET_ACCESS_KEY", "vault:secret/data/test/aws#AWS_SECRET_ACCESS_KEY") - } -} - func newSecretFile(t *testing.T, content string) string { dir := t.TempDir() + "/test/secrets" err := os.MkdirAll(dir, 0o755) diff --git a/main.go b/main.go index 9874985..d536048 100644 --- a/main.go +++ b/main.go @@ -54,9 +54,7 @@ func main() { // Fetch all provider secrets and assemble env variables using envstore envStore := NewEnvStore(config) - secretReferences := envStore.GetSecretReferences() - - providerSecrets, err := envStore.LoadProviderSecrets(context.Background(), secretReferences) + providerSecrets, err := envStore.LoadProviderSecrets(context.Background(), envStore.GetSecretReferences()) if err != nil { slog.Error(fmt.Errorf("failed to extract secrets: %w", err).Error()) os.Exit(1) diff --git a/pkg/provider/aws/aws.go b/pkg/provider/aws/aws.go index adb0d2a..dcfb390 100644 --- a/pkg/provider/aws/aws.go +++ b/pkg/provider/aws/aws.go @@ -24,21 +24,31 @@ import ( "github.com/aws/aws-sdk-go/service/secretsmanager" "github.com/aws/aws-sdk-go/service/ssm" + "github.com/bank-vaults/secret-init/pkg/common" "github.com/bank-vaults/secret-init/pkg/provider" ) -var ProviderName = "aws" +const ( + ProviderType = "aws" + referenceSelectorSM = "arn:aws:secretsmanager:" + referenceSelectorSSM = "arn:aws:ssm:" +) type Provider struct { sm *secretsmanager.SecretsManager ssm *ssm.SSM } -func NewProvider(config *Config) *Provider { +func NewProvider(_ context.Context, _ *common.Config) (provider.Provider, error) { + config, err := LoadConfig() + if err != nil { + return nil, fmt.Errorf("failed to create vault config: %w", err) + } + return &Provider{ sm: secretsmanager.New(config.session), ssm: ssm.New(config.session), - } + }, nil } func (p *Provider) LoadSecrets(ctx context.Context, paths []string) ([]provider.Secret, error) { @@ -101,6 +111,13 @@ func (p *Provider) LoadSecrets(ctx context.Context, paths []string) ([]provider. return secrets, nil } +// Example AWS prefixes: +// arn:aws:secretsmanager:us-west-2:123456789012:secret:my-secret +// arn:aws:ssm:us-west-2:123456789012:parameter/my-parameter +func Valid(envValue string) bool { + return strings.HasPrefix(envValue, referenceSelectorSM) || strings.HasPrefix(envValue, referenceSelectorSSM) +} + // AWS Secrets Manager can store secrets in two formats: // - SecretString: for text-based secrets, returned as a byte slice. // - SecretBinary: for binary secrets, returned as a byte slice without additional encoding. diff --git a/pkg/provider/aws/aws_test.go b/pkg/provider/aws/aws_test.go deleted file mode 100644 index 08dad0d..0000000 --- a/pkg/provider/aws/aws_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright © 2024 Bank-Vaults Maintainers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package aws - -import ( - "testing" - - "github.com/aws/aws-sdk-go/aws/session" - "github.com/stretchr/testify/assert" -) - -func TestNewProvider(t *testing.T) { - tests := []struct { - name string - config *Config - wantType bool - }{ - { - name: "Valid config", - config: &Config{ - session: createSession(), - }, - wantType: true, - }, - } - - for _, tt := range tests { - ttp := tt - t.Run(ttp.name, func(t *testing.T) { - provider := NewProvider(ttp.config) - if ttp.wantType { - assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type") - } - }) - } -} - -func createSession() *session.Session { - return session.Must(session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigDisable, - })) -} diff --git a/pkg/provider/azure/azure.go b/pkg/provider/azure/azure.go index b4435a4..3260cd0 100644 --- a/pkg/provider/azure/azure.go +++ b/pkg/provider/azure/azure.go @@ -22,16 +22,25 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azsecrets" + "github.com/bank-vaults/secret-init/pkg/common" "github.com/bank-vaults/secret-init/pkg/provider" ) -var ProviderName = "azure" +const ( + ProviderType = "azure" + referenceSelector = "azure:keyvault:" +) type Provider struct { client *azsecrets.Client } -func NewProvider(config *Config) (*Provider, error) { +func NewProvider(_ context.Context, _ *common.Config) (provider.Provider, error) { + config, err := LoadConfig() + if err != nil { + return nil, fmt.Errorf("failed to create vault config: %w", err) + } + creds, err := azidentity.NewDefaultAzureCredential(nil) if err != nil { return nil, fmt.Errorf("failed to create default azure credentials: %v", err) @@ -78,3 +87,10 @@ func (p *Provider) LoadSecrets(ctx context.Context, paths []string) ([]provider. return secrets, nil } + +// Example Azure Key Vault secret examples: +// azure:keyvault:{SECRET_NAME} +// azure:keyvault:{SECRET_NAME}/{VERSION} +func Valid(envValue string) bool { + return strings.HasPrefix(envValue, referenceSelector) +} diff --git a/pkg/provider/bao/bao.go b/pkg/provider/bao/bao.go index a7a2893..6d98061 100644 --- a/pkg/provider/bao/bao.go +++ b/pkg/provider/bao/bao.go @@ -30,9 +30,9 @@ import ( "github.com/bank-vaults/secret-init/pkg/provider" ) -var ( - ProviderName = "bao" - ProviderEnvRegex = regexp.MustCompile(`(bao:)(.*)#(.*)`) +const ( + ProviderType = "bao" + referenceSelector = `(bao:)(.*)#(.*)` ) type Provider struct { @@ -66,7 +66,12 @@ func (s *sanitized) append(key string, value string) { } } -func NewProvider(config *Config, appConfig *common.Config) (*Provider, error) { +func NewProvider(_ context.Context, appConfig *common.Config) (provider.Provider, error) { + config, err := LoadConfig() + if err != nil { + return nil, fmt.Errorf("failed to create vault config: %w", err) + } + clientOptions := []bao.ClientOption{bao.ClientLogger(clientLogger{slog.Default()})} if config.TokenFile != "" { clientOptions = append(clientOptions, bao.ClientToken(config.Token)) @@ -118,7 +123,7 @@ func NewProvider(config *Config, appConfig *common.Config) (*Provider, error) { // and the value is the secret value // E.g. paths: MYSQL_PASSWORD=secret/data/mysql/password // returns: []provider.Secret{provider.Secret{Path: "MYSQL_PASSWORD", Value: "password"}} -func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Secret, error) { +func (p *Provider) LoadSecrets(ctx context.Context, paths []string) ([]provider.Secret, error) { sanitized := sanitized{login: p.isLogin} baoEnviron := parsePathsToMap(paths) @@ -141,7 +146,7 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se if p.revokeToken { // ref: https://www.vaultproject.io/api/auth/token/index.html#revoke-a-token-self- - err := p.client.RawClient().Auth().Token().RevokeSelf(p.client.RawClient().Token()) + err := p.client.RawClient().Auth().Token().RevokeSelfWithContext(ctx, p.client.RawClient().Token()) if err != nil { // Do not exit on error, token revoking can be denied by policy slog.Warn("failed to revoke token") @@ -153,6 +158,12 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se return sanitized.secrets, nil } +// If the path contains some string formatted as "bao:{STR}#{STR}" +// it is most probably a vault path +func Valid(envValue string) bool { + return regexp.MustCompile(referenceSelector).MatchString(envValue) +} + func parsePathsToMap(paths []string) map[string]string { baoEnviron := make(map[string]string) diff --git a/pkg/provider/bao/bao_test.go b/pkg/provider/bao/bao_test.go deleted file mode 100644 index 99bd494..0000000 --- a/pkg/provider/bao/bao_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright © 2024 Bank-Vaults Maintainers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bao - -import ( - "fmt" - "io" - "log/slog" - "os" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/bank-vaults/secret-init/pkg/common" -) - -var originalLogger *slog.Logger - -func TestMain(m *testing.M) { - setupTestLogger() - code := m.Run() - restoreLogger() - os.Exit(code) -} - -func TestNewProvider(t *testing.T) { - tests := []struct { - name string - config *Config - err error - wantType bool - }{ - { - name: "Valid Provider with Token", - config: &Config{ - IsLogin: true, - TokenFile: "root", - Token: "root", - TransitKeyID: "test-key", - TransitPath: "transit", - TransitBatchSize: 10, - IgnoreMissingSecrets: true, - FromPath: "secret/data/test", - RevokeToken: true, - }, - wantType: true, - }, - { - name: "Valid Provider with bao:login as Token and daemon mode", - config: &Config{ - IsLogin: true, - Token: baoLogin, - TokenFile: "root", - IgnoreMissingSecrets: true, - FromPath: "secret/data/test", - }, - wantType: true, - }, - { - name: "Fail to create bao client due to timeout", - config: &Config{}, - err: fmt.Errorf("failed to create bao client: timeout [10s] during waiting for Vault token"), - }, - } - - for _, tt := range tests { - ttp := tt - - t.Run(ttp.name, func(t *testing.T) { - provider, err := NewProvider(ttp.config, &common.Config{}) - if err != nil { - assert.EqualError(t, ttp.err, err.Error(), "Unexpected error message") - } - if ttp.wantType { - assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type") - } - }) - } -} - -func setupTestLogger() { - originalLogger = slog.Default() - - // Discard logs to avoid polluting the test output - testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - slog.SetDefault(testLogger) -} - -func restoreLogger() { - slog.SetDefault(originalLogger) -} diff --git a/pkg/provider/bao/config.go b/pkg/provider/bao/config.go index b46e6eb..427ca66 100644 --- a/pkg/provider/bao/config.go +++ b/pkg/provider/bao/config.go @@ -28,37 +28,37 @@ const ( // which was acquired during the bao client initialization. baoLogin = "bao:login" - TokenEnv = "BAO_TOKEN" - TokenFileEnv = "BAO_TOKEN_FILE" - AddrEnv = "BAO_ADDR" - AgentAddrEnv = "BAO_AGENT_ADDR" - CACertEnv = "BAO_CACERT" - CAPathEnv = "BAO_CAPATH" - ClientCertEnv = "BAO_CLIENT_CERT" - ClientKeyEnv = "BAO_CLIENT_KEY" - ClientTimeoutEnv = "BAO_CLIENT_TIMEOUT" - SRVLookupEnv = "BAO_SRV_LOOKUP" - SkipVerifyEnv = "BAO_SKIP_VERIFY" - NamespaceEnv = "BAO_NAMESPACE" - TLSServerNameEnv = "BAO_TLS_SERVER_NAME" - WrapTTLEnv = "BAO_WRAP_TTL" - MFAEnv = "BAO_MFA" - MaxRetriesEnv = "BAO_MAX_RETRIES" - ClusterAddrEnv = "BAO_CLUSTER_ADDR" - RedirectAddrEnv = "BAO_REDIRECT_ADDR" - CLINoColorEnv = "BAO_CLI_NO_COLOR" - RateLimitEnv = "BAO_RATE_LIMIT" - RoleEnv = "BAO_ROLE" - PathEnv = "BAO_PATH" - AuthMethodEnv = "BAO_AUTH_METHOD" - TransitKeyIDEnv = "BAO_TRANSIT_KEY_ID" - TransitPathEnv = "BAO_TRANSIT_PATH" - TransitBatchSizeEnv = "BAO_TRANSIT_BATCH_SIZE" - IgnoreMissingSecretsEnv = "BAO_IGNORE_MISSING_SECRETS" - PassthroughEnv = "BAO_PASSTHROUGH" - LogLevelEnv = "BAO_LOG_LEVEL" - RevokeTokenEnv = "BAO_REVOKE_TOKEN" - FromPathEnv = "BAO_FROM_PATH" + tokenEnv = "BAO_TOKEN" + tokenFileEnv = "BAO_TOKEN_FILE" + addrEnv = "BAO_ADDR" + agentAddrEnv = "BAO_AGENT_ADDR" + caCertEnv = "BAO_CACERT" + caPathEnv = "BAO_CAPATH" + clientCertEnv = "BAO_CLIENT_CERT" + clientKeyEnv = "BAO_CLIENT_KEY" + clientTimeoutEnv = "BAO_CLIENT_TIMEOUT" + srvLookupEnv = "BAO_SRV_LOOKUP" + skipVerifyEnv = "BAO_SKIP_VERIFY" + namespaceEnv = "BAO_NAMESPACE" + tlsServerNameEnv = "BAO_TLS_SERVER_NAME" + wrapTTLEnv = "BAO_WRAP_TTL" + mfaEnv = "BAO_MFA" + maxRetriesEnv = "BAO_MAX_RETRIES" + clusterAddrEnv = "BAO_CLUSTER_ADDR" + redirectAddrEnv = "BAO_REDIRECT_ADDR" + cliNoColorEnv = "BAO_CLI_NO_COLOR" + rateLimitEnv = "BAO_RATE_LIMIT" + roleEnv = "BAO_ROLE" + pathEnv = "BAO_PATH" + authMethodEnv = "BAO_AUTH_METHOD" + transitKeyIDEnv = "BAO_TRANSIT_KEY_ID" + transitPathEnv = "BAO_TRANSIT_PATH" + transitBatchSizeEnv = "BAO_TRANSIT_BATCH_SIZE" + ignoreMissingSecretsEnv = "BAO_IGNORE_MISSING_SECRETS" + passthroughEnv = "BAO_PASSTHROUGH" + logLevelEnv = "BAO_LOG_LEVEL" + revokeTokenEnv = "BAO_REVOKE_TOKEN" + fromPathEnv = "BAO_FROM_PATH" ) type Config struct { @@ -81,36 +81,36 @@ type envType struct { } var sanitizeEnvmap = map[string]envType{ - TokenEnv: {login: true}, - AddrEnv: {login: true}, - AgentAddrEnv: {login: true}, - CACertEnv: {login: true}, - CAPathEnv: {login: true}, - ClientCertEnv: {login: true}, - ClientKeyEnv: {login: true}, - ClientTimeoutEnv: {login: true}, - SRVLookupEnv: {login: true}, - SkipVerifyEnv: {login: true}, - NamespaceEnv: {login: true}, - TLSServerNameEnv: {login: true}, - WrapTTLEnv: {login: true}, - MFAEnv: {login: true}, - MaxRetriesEnv: {login: true}, - ClusterAddrEnv: {login: false}, - RedirectAddrEnv: {login: false}, - CLINoColorEnv: {login: false}, - RateLimitEnv: {login: false}, - RoleEnv: {login: false}, - PathEnv: {login: false}, - AuthMethodEnv: {login: false}, - TransitKeyIDEnv: {login: false}, - TransitPathEnv: {login: false}, - TransitBatchSizeEnv: {login: false}, - IgnoreMissingSecretsEnv: {login: false}, - PassthroughEnv: {login: false}, - LogLevelEnv: {login: false}, - RevokeTokenEnv: {login: false}, - FromPathEnv: {login: false}, + tokenEnv: {login: true}, + addrEnv: {login: true}, + agentAddrEnv: {login: true}, + caCertEnv: {login: true}, + caPathEnv: {login: true}, + clientCertEnv: {login: true}, + clientKeyEnv: {login: true}, + clientTimeoutEnv: {login: true}, + srvLookupEnv: {login: true}, + skipVerifyEnv: {login: true}, + namespaceEnv: {login: true}, + tlsServerNameEnv: {login: true}, + wrapTTLEnv: {login: true}, + mfaEnv: {login: true}, + maxRetriesEnv: {login: true}, + clusterAddrEnv: {login: false}, + redirectAddrEnv: {login: false}, + cliNoColorEnv: {login: false}, + rateLimitEnv: {login: false}, + roleEnv: {login: false}, + pathEnv: {login: false}, + authMethodEnv: {login: false}, + transitKeyIDEnv: {login: false}, + transitPathEnv: {login: false}, + transitBatchSizeEnv: {login: false}, + ignoreMissingSecretsEnv: {login: false}, + passthroughEnv: {login: false}, + logLevelEnv: {login: false}, + revokeTokenEnv: {login: false}, + fromPathEnv: {login: false}, } func LoadConfig() (*Config, error) { @@ -122,15 +122,15 @@ func LoadConfig() (*Config, error) { // This workaround is necessary because the BAO_ADDR // is not yet used directly by the Bao client. // This is why env_store.go/workaroundForBao() has been implemented. - baoAddr := os.Getenv(AddrEnv) + baoAddr := os.Getenv(addrEnv) os.Setenv("VAULT_ADDR", baoAddr) // The login procedure takes the token from a file (if using Bao Agent) // or requests one for itself (Kubernetes Auth, or GCP, etc...), // so if we got a BAO_TOKEN for the special value with "bao:login" - baoToken := os.Getenv(TokenEnv) + baoToken := os.Getenv(tokenEnv) isLogin := baoToken == baoLogin - tokenFile, ok := os.LookupEnv(TokenFileEnv) + tokenFile, ok := os.LookupEnv(tokenFileEnv) if ok { // load token from bao-agent .bao-token or injected webhook tokenFileContent, err := os.ReadFile(tokenFile) @@ -140,28 +140,28 @@ func LoadConfig() (*Config, error) { baoToken = string(tokenFileContent) } else { if isLogin { - _ = os.Unsetenv(TokenEnv) + _ = os.Unsetenv(tokenEnv) } // will use role/path based authentication - role, hasRole = os.LookupEnv(RoleEnv) + role, hasRole = os.LookupEnv(roleEnv) if !hasRole { - return nil, fmt.Errorf("incomplete authentication configuration: %s missing", RoleEnv) + return nil, fmt.Errorf("incomplete authentication configuration: %s missing", roleEnv) } - authPath, hasPath = os.LookupEnv(PathEnv) + authPath, hasPath = os.LookupEnv(pathEnv) if !hasPath { - return nil, fmt.Errorf("incomplete authentication configuration: %s missing", PathEnv) + return nil, fmt.Errorf("incomplete authentication configuration: %s missing", pathEnv) } - authMethod, hasAuthMethod = os.LookupEnv(AuthMethodEnv) + authMethod, hasAuthMethod = os.LookupEnv(authMethodEnv) if !hasAuthMethod { - return nil, fmt.Errorf("incomplete authentication configuration: %s missing", AuthMethodEnv) + return nil, fmt.Errorf("incomplete authentication configuration: %s missing", authMethodEnv) } } - passthroughEnvVars := strings.Split(os.Getenv(PassthroughEnv), ",") + passthroughEnvVars := strings.Split(os.Getenv(passthroughEnv), ",") if isLogin { - _ = os.Setenv(TokenEnv, baoLogin) - passthroughEnvVars = append(passthroughEnvVars, TokenEnv) + _ = os.Setenv(tokenEnv, baoLogin) + passthroughEnvVars = append(passthroughEnvVars, tokenEnv) } // do not sanitize env vars specified in BAO_PASSTHROUGH @@ -178,11 +178,11 @@ func LoadConfig() (*Config, error) { Role: role, AuthPath: authPath, AuthMethod: authMethod, - TransitKeyID: os.Getenv(TransitKeyIDEnv), - TransitPath: os.Getenv(TransitPathEnv), - TransitBatchSize: cast.ToInt(os.Getenv(TransitBatchSizeEnv)), - IgnoreMissingSecrets: cast.ToBool(os.Getenv(IgnoreMissingSecretsEnv)), // Used both for reading secrets and transit encryption - FromPath: os.Getenv(FromPathEnv), - RevokeToken: cast.ToBool(os.Getenv(RevokeTokenEnv)), + TransitKeyID: os.Getenv(transitKeyIDEnv), + TransitPath: os.Getenv(transitPathEnv), + TransitBatchSize: cast.ToInt(os.Getenv(transitBatchSizeEnv)), + IgnoreMissingSecrets: cast.ToBool(os.Getenv(ignoreMissingSecretsEnv)), // Used both for reading secrets and transit encryption + FromPath: os.Getenv(fromPathEnv), + RevokeToken: cast.ToBool(os.Getenv(revokeTokenEnv)), }, nil } diff --git a/pkg/provider/bao/config_test.go b/pkg/provider/bao/config_test.go index a068a30..a25b6e5 100644 --- a/pkg/provider/bao/config_test.go +++ b/pkg/provider/bao/config_test.go @@ -35,15 +35,15 @@ func TestConfig(t *testing.T) { { name: "Valid login configuration with Token", env: map[string]string{ - TokenEnv: baoLogin, - TokenFileEnv: tokenFile, - PassthroughEnv: AgentAddrEnv + ", " + CLINoColorEnv, - TransitKeyIDEnv: "test-key", - TransitPathEnv: "transit", - TransitBatchSizeEnv: "10", - IgnoreMissingSecretsEnv: "true", - RevokeTokenEnv: "true", - FromPathEnv: "secret/data/test", + tokenEnv: baoLogin, + tokenFileEnv: tokenFile, + passthroughEnv: agentAddrEnv + ", " + cliNoColorEnv, + transitKeyIDEnv: "test-key", + transitPathEnv: "transit", + transitBatchSizeEnv: "10", + ignoreMissingSecretsEnv: "true", + revokeTokenEnv: "true", + fromPathEnv: "secret/data/test", }, wantConfig: &Config{ IsLogin: true, @@ -60,10 +60,10 @@ func TestConfig(t *testing.T) { { name: "Valid login configuration with Role and Path", env: map[string]string{ - TokenEnv: baoLogin, - RoleEnv: "test-app-role", - PathEnv: "auth/approle/test/login", - AuthMethodEnv: "test-approle", + tokenEnv: baoLogin, + roleEnv: "test-app-role", + pathEnv: "auth/approle/test/login", + authMethodEnv: "test-approle", }, wantConfig: &Config{ IsLogin: true, @@ -76,31 +76,31 @@ func TestConfig(t *testing.T) { { name: "Invalid login configuration using tokenfile - missing token file", env: map[string]string{ - TokenFileEnv: tokenFile + "/invalid", + tokenFileEnv: tokenFile + "/invalid", }, err: fmt.Errorf("failed to read token file %s/invalid: open %s/invalid: not a directory", tokenFile, tokenFile), }, { name: "Invalid login configuration using role/path - missing role", env: map[string]string{ - PathEnv: "auth/approle/test/login", - AuthMethodEnv: "k8s", + pathEnv: "auth/approle/test/login", + authMethodEnv: "k8s", }, err: fmt.Errorf("incomplete authentication configuration: BAO_ROLE missing"), }, { name: "Invalid login configuration using role/path - missing path", env: map[string]string{ - RoleEnv: "test-app-role", - AuthMethodEnv: "k8s", + roleEnv: "test-app-role", + authMethodEnv: "k8s", }, err: fmt.Errorf("incomplete authentication configuration: BAO_PATH missing"), }, { name: "Invalid login configuration using role/path - missing auth method", env: map[string]string{ - RoleEnv: "test-app-role", - PathEnv: "auth/approle/test/login", + roleEnv: "test-app-role", + pathEnv: "auth/approle/test/login", }, err: fmt.Errorf("incomplete authentication configuration: BAO_AUTH_METHOD missing"), }, diff --git a/pkg/provider/file/file.go b/pkg/provider/file/file.go index 0b2fdeb..0c57773 100644 --- a/pkg/provider/file/file.go +++ b/pkg/provider/file/file.go @@ -21,16 +21,22 @@ import ( "os" "strings" + "github.com/bank-vaults/secret-init/pkg/common" "github.com/bank-vaults/secret-init/pkg/provider" ) -const ProviderName = "file" +const ( + ProviderType = "file" + referenceSelector = "file:" +) type Provider struct { fs fs.FS } -func NewProvider(config *Config) (provider.Provider, error) { +func NewProvider(_ context.Context, _ *common.Config) (provider.Provider, error) { + config := LoadConfig() + // Check whether the path exists fileInfo, err := os.Stat(config.MountPath) if err != nil { @@ -66,6 +72,10 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se return secrets, nil } +func Valid(envValue string) bool { + return strings.HasPrefix(envValue, referenceSelector) +} + func (p *Provider) getSecretFromFile(valuePath string) (string, error) { valuePath = strings.TrimLeft(valuePath, "/") content, err := fs.ReadFile(p.fs, valuePath) diff --git a/pkg/provider/file/file_test.go b/pkg/provider/file/file_test.go index 51bbe44..f3139a6 100644 --- a/pkg/provider/file/file_test.go +++ b/pkg/provider/file/file_test.go @@ -17,8 +17,6 @@ package file import ( "context" "fmt" - "io/fs" - "os" "testing" "testing/fstest" @@ -27,60 +25,6 @@ import ( "github.com/bank-vaults/secret-init/pkg/provider" ) -func TestNewProvider(t *testing.T) { - tempDir := t.TempDir() - secretFile := newSecretFile(t, "3xtr3ms3cr3t") - defer os.Remove(secretFile) - - tests := []struct { - name string - config *Config - err error - wantType bool - wantFs fs.FS - }{ - { - name: "Valid config - directory", - config: &Config{ - MountPath: tempDir, - }, - wantType: true, - wantFs: os.DirFS(tempDir), - }, - { - name: "Invalid config - directory does not exist", - config: &Config{ - MountPath: "test/secrets/invalid", - }, - err: fmt.Errorf("failed to access path: stat test/secrets/invalid: no such file or directory"), - }, - { - name: "Invalid config - file instead of directory", - config: &Config{ - MountPath: secretFile, - }, - err: fmt.Errorf("provided path is not a directory"), - }, - } - - for _, tt := range tests { - ttp := tt - t.Run(ttp.name, func(t *testing.T) { - provider, err := NewProvider(ttp.config) - if err != nil { - assert.EqualError(t, err, ttp.err.Error(), "Unexpected error message") - } - if ttp.wantType { - assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type") - - if ttp.wantFs != nil { - assert.Equal(t, ttp.wantFs, provider.(*Provider).fs, "Unexpected file system") - } - } - }) - } -} - func TestLoadSecrets(t *testing.T) { tests := []struct { name string @@ -131,18 +75,3 @@ func TestLoadSecrets(t *testing.T) { }) } } - -func newSecretFile(t *testing.T, content string) string { - dir := t.TempDir() + "/test/secrets" - err := os.MkdirAll(dir, 0o755) - assert.Nil(t, err, "Failed to create directory") - - file, err := os.CreateTemp(dir, "secret.txt") - assert.Nil(t, err, "Failed to create a temporary file") - defer file.Close() - - _, err = file.WriteString(content) - assert.Nil(t, err, "Failed to write to the temporary file") - - return file.Name() -} diff --git a/pkg/provider/gcp/gcp.go b/pkg/provider/gcp/gcp.go index fa913fb..913a818 100644 --- a/pkg/provider/gcp/gcp.go +++ b/pkg/provider/gcp/gcp.go @@ -23,16 +23,21 @@ import ( secretmanager "cloud.google.com/go/secretmanager/apiv1" "cloud.google.com/go/secretmanager/apiv1/secretmanagerpb" + "github.com/bank-vaults/secret-init/pkg/common" "github.com/bank-vaults/secret-init/pkg/provider" ) -var ProviderName = "gcp" +const ( + ProviderType = "gcp" + referenceSelector = "gcp:secretmanager:" + versionRegex = `.*/versions/(latest|\d+)$` +) type Provider struct { client *secretmanager.Client } -func NewProvider(ctx context.Context) (*Provider, error) { +func NewProvider(ctx context.Context, _ *common.Config) (provider.Provider, error) { // This will automatically use the Application Default Credentials (ADC) strategy for authentication. // If the GOOGLE_APPLICATION_CREDENTIALS environment variable is set, // the client will use the service account key JSON file that the variable points to. @@ -85,9 +90,16 @@ func (p *Provider) LoadSecrets(ctx context.Context, paths []string) ([]provider. return secrets, nil } +// Example GCP prefixes: +// gcp:secretmanager:projects/{PROJECT_ID}/secrets/{SECRET_NAME} +// gcp:secretmanager:projects/{PROJECT_ID}/secrets/{SECRET_NAME}/versions/{VERSION|latest} +func Valid(envValue string) bool { + return strings.HasPrefix(envValue, referenceSelector) +} + func handleVersion(secretID string) (string, error) { // If the version is correctly specified, return the secretID as is - match, err := regexp.MatchString(`.*/versions/(latest|\d+)$`, secretID) + match, err := regexp.MatchString(versionRegex, secretID) if err != nil { return "", fmt.Errorf("failed to match secret ID with regex: %v", err) } diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 65b0cd9..5d7b61e 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -14,10 +14,21 @@ package provider -import "context" +import ( + "context" + + "github.com/bank-vaults/secret-init/pkg/common" +) + +type Factory struct { + ProviderType string + Validator func(envValue string) bool + Create func(ctx context.Context, cfg *common.Config) (Provider, error) +} // Provider is an interface for securely loading secrets based on environment variables. type Provider interface { + // LoadSecrets loads secrets from the provider based on the given paths LoadSecrets(ctx context.Context, paths []string) ([]Secret, error) } diff --git a/pkg/provider/vault/config.go b/pkg/provider/vault/config.go index 800db50..5b5711b 100644 --- a/pkg/provider/vault/config.go +++ b/pkg/provider/vault/config.go @@ -28,37 +28,37 @@ const ( // which was acquired during the vault client initialization. vaultLogin = "vault:login" - TokenEnv = "VAULT_TOKEN" - TokenFileEnv = "VAULT_TOKEN_FILE" - AddrEnv = "VAULT_ADDR" - AgentAddrEnv = "VAULT_AGENT_ADDR" - CACertEnv = "VAULT_CACERT" - CAPathEnv = "VAULT_CAPATH" - ClientCertEnv = "VAULT_CLIENT_CERT" - ClientKeyEnv = "VAULT_CLIENT_KEY" - ClientTimeoutEnv = "VAULT_CLIENT_TIMEOUT" - SRVLookupEnv = "VAULT_SRV_LOOKUP" - SkipVerifyEnv = "VAULT_SKIP_VERIFY" - NamespaceEnv = "VAULT_NAMESPACE" - TLSServerNameEnv = "VAULT_TLS_SERVER_NAME" - WrapTTLEnv = "VAULT_WRAP_TTL" - MFAEnv = "VAULT_MFA" - MaxRetriesEnv = "VAULT_MAX_RETRIES" - ClusterAddrEnv = "VAULT_CLUSTER_ADDR" - RedirectAddrEnv = "VAULT_REDIRECT_ADDR" - CLINoColorEnv = "VAULT_CLI_NO_COLOR" - RateLimitEnv = "VAULT_RATE_LIMIT" - RoleEnv = "VAULT_ROLE" - PathEnv = "VAULT_PATH" - AuthMethodEnv = "VAULT_AUTH_METHOD" - TransitKeyIDEnv = "VAULT_TRANSIT_KEY_ID" - TransitPathEnv = "VAULT_TRANSIT_PATH" - TransitBatchSizeEnv = "VAULT_TRANSIT_BATCH_SIZE" - IgnoreMissingSecretsEnv = "VAULT_IGNORE_MISSING_SECRETS" - PassthroughEnv = "VAULT_PASSTHROUGH" - LogLevelEnv = "VAULT_LOG_LEVEL" - RevokeTokenEnv = "VAULT_REVOKE_TOKEN" - FromPathEnv = "VAULT_FROM_PATH" + tokenEnv = "VAULT_TOKEN" + tokenFileEnv = "VAULT_TOKEN_FILE" + addrEnv = "VAULT_ADDR" + agentAddrEnv = "VAULT_AGENT_ADDR" + caCertEnv = "VAULT_CACERT" + caPathEnv = "VAULT_CAPATH" + clientCertEnv = "VAULT_CLIENT_CERT" + clientKeyEnv = "VAULT_CLIENT_KEY" + clientTimeoutEnv = "VAULT_CLIENT_TIMEOUT" + srvLookupEnv = "VAULT_SRV_LOOKUP" + skipVerifyEnv = "VAULT_SKIP_VERIFY" + namespaceEnv = "VAULT_NAMESPACE" + tlsServerNameEnv = "VAULT_TLS_SERVER_NAME" + wrapTTLEnv = "VAULT_WRAP_TTL" + mfaEnv = "VAULT_MFA" + maxRetriesEnv = "VAULT_MAX_RETRIES" + clusterAddrEnv = "VAULT_CLUSTER_ADDR" + redirectAddrEnv = "VAULT_REDIRECT_ADDR" + cliNoColorEnv = "VAULT_CLI_NO_COLOR" + rateLimitEnv = "VAULT_RATE_LIMIT" + roleEnv = "VAULT_ROLE" + pathEnv = "VAULT_PATH" + authMethodEnv = "VAULT_AUTH_METHOD" + transitKeyIDEnv = "VAULT_TRANSIT_KEY_ID" + transitPathEnv = "VAULT_TRANSIT_PATH" + transitBatchSizeEnv = "VAULT_TRANSIT_BATCH_SIZE" + ignoreMissingSecretsEnv = "VAULT_IGNORE_MISSING_SECRETS" + passthroughEnv = "VAULT_PASSTHROUGH" + logLevelEnv = "VAULT_LOG_LEVEL" + revokeTokenEnv = "VAULT_REVOKE_TOKEN" + fromPathEnv = "VAULT_FROM_PATH" ) type Config struct { @@ -81,36 +81,36 @@ type envType struct { } var sanitizeEnvmap = map[string]envType{ - TokenEnv: {login: true}, - AddrEnv: {login: true}, - AgentAddrEnv: {login: true}, - CACertEnv: {login: true}, - CAPathEnv: {login: true}, - ClientCertEnv: {login: true}, - ClientKeyEnv: {login: true}, - ClientTimeoutEnv: {login: true}, - SRVLookupEnv: {login: true}, - SkipVerifyEnv: {login: true}, - NamespaceEnv: {login: true}, - TLSServerNameEnv: {login: true}, - WrapTTLEnv: {login: true}, - MFAEnv: {login: true}, - MaxRetriesEnv: {login: true}, - ClusterAddrEnv: {login: false}, - RedirectAddrEnv: {login: false}, - CLINoColorEnv: {login: false}, - RateLimitEnv: {login: false}, - RoleEnv: {login: false}, - PathEnv: {login: false}, - AuthMethodEnv: {login: false}, - TransitKeyIDEnv: {login: false}, - TransitPathEnv: {login: false}, - TransitBatchSizeEnv: {login: false}, - IgnoreMissingSecretsEnv: {login: false}, - PassthroughEnv: {login: false}, - LogLevelEnv: {login: false}, - RevokeTokenEnv: {login: false}, - FromPathEnv: {login: false}, + tokenEnv: {login: true}, + addrEnv: {login: true}, + agentAddrEnv: {login: true}, + caCertEnv: {login: true}, + caPathEnv: {login: true}, + clientCertEnv: {login: true}, + clientKeyEnv: {login: true}, + clientTimeoutEnv: {login: true}, + srvLookupEnv: {login: true}, + skipVerifyEnv: {login: true}, + namespaceEnv: {login: true}, + tlsServerNameEnv: {login: true}, + wrapTTLEnv: {login: true}, + mfaEnv: {login: true}, + maxRetriesEnv: {login: true}, + clusterAddrEnv: {login: false}, + redirectAddrEnv: {login: false}, + cliNoColorEnv: {login: false}, + rateLimitEnv: {login: false}, + roleEnv: {login: false}, + pathEnv: {login: false}, + authMethodEnv: {login: false}, + transitKeyIDEnv: {login: false}, + transitPathEnv: {login: false}, + transitBatchSizeEnv: {login: false}, + ignoreMissingSecretsEnv: {login: false}, + passthroughEnv: {login: false}, + logLevelEnv: {login: false}, + revokeTokenEnv: {login: false}, + fromPathEnv: {login: false}, } func LoadConfig() (*Config, error) { @@ -122,9 +122,9 @@ func LoadConfig() (*Config, error) { // The login procedure takes the token from a file (if using Vault Agent) // or requests one for itself (Kubernetes Auth, or GCP, etc...), // so if we got a VAULT_TOKEN for the special value with "vault:login" - vaultToken := os.Getenv(TokenEnv) + vaultToken := os.Getenv(tokenEnv) isLogin := vaultToken == vaultLogin - tokenFile, ok := os.LookupEnv(TokenFileEnv) + tokenFile, ok := os.LookupEnv(tokenFileEnv) if ok { // load token from vault-agent .vault-token or injected webhook tokenFileContent, err := os.ReadFile(tokenFile) @@ -134,28 +134,28 @@ func LoadConfig() (*Config, error) { vaultToken = string(tokenFileContent) } else { if isLogin { - _ = os.Unsetenv(TokenEnv) + _ = os.Unsetenv(tokenEnv) } // will use role/path based authentication - role, hasRole = os.LookupEnv(RoleEnv) + role, hasRole = os.LookupEnv(roleEnv) if !hasRole { - return nil, fmt.Errorf("incomplete authentication configuration: %s missing", RoleEnv) + return nil, fmt.Errorf("incomplete authentication configuration: %s missing", roleEnv) } - authPath, hasPath = os.LookupEnv(PathEnv) + authPath, hasPath = os.LookupEnv(pathEnv) if !hasPath { - return nil, fmt.Errorf("incomplete authentication configuration: %s missing", PathEnv) + return nil, fmt.Errorf("incomplete authentication configuration: %s missing", pathEnv) } - authMethod, hasAuthMethod = os.LookupEnv(AuthMethodEnv) + authMethod, hasAuthMethod = os.LookupEnv(authMethodEnv) if !hasAuthMethod { - return nil, fmt.Errorf("incomplete authentication configuration: %s missing", AuthMethodEnv) + return nil, fmt.Errorf("incomplete authentication configuration: %s missing", authMethodEnv) } } - passthroughEnvVars := strings.Split(os.Getenv(PassthroughEnv), ",") + passthroughEnvVars := strings.Split(os.Getenv(passthroughEnv), ",") if isLogin { - _ = os.Setenv(TokenEnv, vaultLogin) - passthroughEnvVars = append(passthroughEnvVars, TokenEnv) + _ = os.Setenv(tokenEnv, vaultLogin) + passthroughEnvVars = append(passthroughEnvVars, tokenEnv) } // do not sanitize env vars specified in VAULT_PASSTHROUGH @@ -172,11 +172,11 @@ func LoadConfig() (*Config, error) { Role: role, AuthPath: authPath, AuthMethod: authMethod, - TransitKeyID: os.Getenv(TransitKeyIDEnv), - TransitPath: os.Getenv(TransitPathEnv), - TransitBatchSize: cast.ToInt(os.Getenv(TransitBatchSizeEnv)), - IgnoreMissingSecrets: cast.ToBool(os.Getenv(IgnoreMissingSecretsEnv)), // Used both for reading secrets and transit encryption - FromPath: os.Getenv(FromPathEnv), - RevokeToken: cast.ToBool(os.Getenv(RevokeTokenEnv)), + TransitKeyID: os.Getenv(transitKeyIDEnv), + TransitPath: os.Getenv(transitPathEnv), + TransitBatchSize: cast.ToInt(os.Getenv(transitBatchSizeEnv)), + IgnoreMissingSecrets: cast.ToBool(os.Getenv(ignoreMissingSecretsEnv)), // Used both for reading secrets and transit encryption + FromPath: os.Getenv(fromPathEnv), + RevokeToken: cast.ToBool(os.Getenv(revokeTokenEnv)), }, nil } diff --git a/pkg/provider/vault/config_test.go b/pkg/provider/vault/config_test.go index e891177..07f75bf 100644 --- a/pkg/provider/vault/config_test.go +++ b/pkg/provider/vault/config_test.go @@ -35,15 +35,15 @@ func TestConfig(t *testing.T) { { name: "Valid login configuration with Token", env: map[string]string{ - TokenEnv: vaultLogin, - TokenFileEnv: tokenFile, - PassthroughEnv: AgentAddrEnv + ", " + CLINoColorEnv, - TransitKeyIDEnv: "test-key", - TransitPathEnv: "transit", - TransitBatchSizeEnv: "10", - IgnoreMissingSecretsEnv: "true", - RevokeTokenEnv: "true", - FromPathEnv: "secret/data/test", + tokenEnv: vaultLogin, + tokenFileEnv: tokenFile, + passthroughEnv: agentAddrEnv + ", " + cliNoColorEnv, + transitKeyIDEnv: "test-key", + transitPathEnv: "transit", + transitBatchSizeEnv: "10", + ignoreMissingSecretsEnv: "true", + revokeTokenEnv: "true", + fromPathEnv: "secret/data/test", }, wantConfig: &Config{ IsLogin: true, @@ -60,10 +60,10 @@ func TestConfig(t *testing.T) { { name: "Valid login configuration with Role and Path", env: map[string]string{ - TokenEnv: vaultLogin, - RoleEnv: "test-app-role", - PathEnv: "auth/approle/test/login", - AuthMethodEnv: "test-approle", + tokenEnv: vaultLogin, + roleEnv: "test-app-role", + pathEnv: "auth/approle/test/login", + authMethodEnv: "test-approle", }, wantConfig: &Config{ IsLogin: true, @@ -76,31 +76,31 @@ func TestConfig(t *testing.T) { { name: "Invalid login configuration using tokenfile - missing token file", env: map[string]string{ - TokenFileEnv: tokenFile + "/invalid", + tokenFileEnv: tokenFile + "/invalid", }, err: fmt.Errorf("failed to read token file %s/invalid: open %s/invalid: not a directory", tokenFile, tokenFile), }, { name: "Invalid login configuration using role/path - missing role", env: map[string]string{ - PathEnv: "auth/approle/test/login", - AuthMethodEnv: "k8s", + pathEnv: "auth/approle/test/login", + authMethodEnv: "k8s", }, err: fmt.Errorf("incomplete authentication configuration: VAULT_ROLE missing"), }, { name: "Invalid login configuration using role/path - missing path", env: map[string]string{ - RoleEnv: "test-app-role", - AuthMethodEnv: "k8s", + roleEnv: "test-app-role", + authMethodEnv: "k8s", }, err: fmt.Errorf("incomplete authentication configuration: VAULT_PATH missing"), }, { name: "Invalid login configuration using role/path - missing auth method", env: map[string]string{ - RoleEnv: "test-app-role", - PathEnv: "auth/approle/test/login", + roleEnv: "test-app-role", + pathEnv: "auth/approle/test/login", }, err: fmt.Errorf("incomplete authentication configuration: VAULT_AUTH_METHOD missing"), }, diff --git a/pkg/provider/vault/vault.go b/pkg/provider/vault/vault.go index 09ab314..27e6512 100644 --- a/pkg/provider/vault/vault.go +++ b/pkg/provider/vault/vault.go @@ -30,9 +30,9 @@ import ( "github.com/bank-vaults/secret-init/pkg/provider" ) -var ( - ProviderName = "vault" - ProviderEnvRegex = regexp.MustCompile(`(vault:)(.*)#(.*)`) +const ( + ProviderType = "vault" + referenceSelector = `(vault:)(.*)#(.*)` ) type Provider struct { @@ -66,7 +66,12 @@ func (s *sanitized) append(key string, value string) { } } -func NewProvider(config *Config, appConfig *common.Config) (provider.Provider, error) { +func NewProvider(_ context.Context, appConfig *common.Config) (provider.Provider, error) { + config, err := LoadConfig() + if err != nil { + return nil, fmt.Errorf("failed to create vault config: %w", err) + } + clientOptions := []vault.ClientOption{vault.ClientLogger(clientLogger{slog.Default()})} if config.TokenFile != "" { clientOptions = append(clientOptions, vault.ClientToken(config.Token)) @@ -118,7 +123,7 @@ func NewProvider(config *Config, appConfig *common.Config) (provider.Provider, e // and the value is the secret value // E.g. paths: MYSQL_PASSWORD=secret/data/mysql/password // returns: []provider.Secret{provider.Secret{Path: "MYSQL_PASSWORD", Value: "password"}} -func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Secret, error) { +func (p *Provider) LoadSecrets(ctx context.Context, paths []string) ([]provider.Secret, error) { sanitized := sanitized{login: p.isLogin} vaultEnviron := parsePathsToMap(paths) @@ -141,7 +146,7 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se if p.revokeToken { // ref: https://www.vaultproject.io/api/auth/token/index.html#revoke-a-token-self- - err := p.client.RawClient().Auth().Token().RevokeSelf(p.client.RawClient().Token()) + err := p.client.RawClient().Auth().Token().RevokeSelfWithContext(ctx, p.client.RawClient().Token()) if err != nil { // Do not exit on error, token revoking can be denied by policy slog.Warn("failed to revoke token") @@ -153,6 +158,12 @@ func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Se return sanitized.secrets, nil } +// If the path contains some string formatted as "vault:{STR}#{STR}" +// it is most probably a vault path +func Valid(envValue string) bool { + return regexp.MustCompile(referenceSelector).MatchString(envValue) +} + func parsePathsToMap(paths []string) map[string]string { vaultEnviron := make(map[string]string) diff --git a/pkg/provider/vault/vault_test.go b/pkg/provider/vault/vault_test.go deleted file mode 100644 index 51afa39..0000000 --- a/pkg/provider/vault/vault_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright © 2023 Bank-Vaults Maintainers -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vault - -import ( - "fmt" - "io" - "log/slog" - "os" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/bank-vaults/secret-init/pkg/common" -) - -var originalLogger *slog.Logger - -func TestMain(m *testing.M) { - setupTestLogger() - code := m.Run() - restoreLogger() - os.Exit(code) -} - -func TestNewProvider(t *testing.T) { - tests := []struct { - name string - config *Config - err error - wantType bool - }{ - { - name: "Valid Provider with Token", - config: &Config{ - IsLogin: true, - TokenFile: "root", - Token: "root", - TransitKeyID: "test-key", - TransitPath: "transit", - TransitBatchSize: 10, - IgnoreMissingSecrets: true, - FromPath: "secret/data/test", - RevokeToken: true, - }, - wantType: true, - }, - { - name: "Valid Provider with vault:login as Token and daemon mode", - config: &Config{ - IsLogin: true, - Token: vaultLogin, - TokenFile: "root", - IgnoreMissingSecrets: true, - FromPath: "secret/data/test", - }, - wantType: true, - }, - { - name: "Fail to create vault client due to timeout", - config: &Config{}, - err: fmt.Errorf("failed to create vault client: timeout [10s] during waiting for Vault token"), - }, - } - - for _, tt := range tests { - ttp := tt - - t.Run(ttp.name, func(t *testing.T) { - provider, err := NewProvider(ttp.config, &common.Config{}) - if err != nil { - assert.EqualError(t, ttp.err, err.Error(), "Unexpected error message") - } - if ttp.wantType { - assert.Equal(t, ttp.wantType, provider != nil, "Unexpected provider type") - } - }) - } -} - -func setupTestLogger() { - originalLogger = slog.Default() - - // Discard logs to avoid polluting the test output - testLogger := slog.New(slog.NewTextHandler(io.Discard, nil)) - slog.SetDefault(testLogger) -} - -func restoreLogger() { - slog.SetDefault(originalLogger) -}