From 06fec462ccc9bd48db8c69e99970a4b0057659c3 Mon Sep 17 00:00:00 2001 From: soeren <2378192-soerenschneider@users.noreply.gitlab.com> Date: Tue, 12 Oct 2021 21:34:07 +0200 Subject: [PATCH] feat: Better validation --- internal/config/common.go | 13 ++++- internal/config/common_test.go | 101 +++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 internal/config/common_test.go diff --git a/internal/config/common.go b/internal/config/common.go index 4e63988..cc02ac4 100644 --- a/internal/config/common.go +++ b/internal/config/common.go @@ -2,6 +2,7 @@ package config import ( "errors" + "fmt" "github.com/rs/zerolog/log" "net/url" "strings" @@ -23,6 +24,7 @@ func (conf *VaultConfig) IsTokenIncreaseEnabled() bool { func (conf *VaultConfig) Print() { log.Info().Msgf("VaultAddr=%s", conf.VaultAddr) + log.Info().Msgf("PathPrefix=%s", conf.PathPrefix) if len(conf.RoleId) > 0 { log.Info().Msgf("VaultRoleId=%s", conf.RoleId) } @@ -38,7 +40,6 @@ func (conf *VaultConfig) Print() { if conf.TokenIncreaseInterval > 0 { log.Info().Msgf("TokenIncreaseInterval=%d", conf.TokenIncreaseInterval) } - // TODO: Check pathPrefix } func DefaultVaultConfig() VaultConfig { @@ -57,6 +58,16 @@ func (conf *VaultConfig) Validate() error { if len(conf.VaultAddr) == 0 { return errors.New("no Vault address defined") } + addr, err := url.ParseRequestURI(conf.VaultAddr) + if err != nil || addr.Scheme == "" || addr.Host == "" || addr.Port() == "" { + return errors.New("can not parse supplied vault addr as url") + } + + for _, prefix := range []string{"/", "secret/"} { + if strings.HasPrefix(conf.PathPrefix, prefix) { + return fmt.Errorf("vault path prefix must not start with %s", prefix) + } + } validRoleIdCredentials := len(conf.SecretId) > 0 && len(conf.RoleId) > 0 if !validRoleIdCredentials && len(conf.VaultToken) == 0 { diff --git a/internal/config/common_test.go b/internal/config/common_test.go new file mode 100644 index 0000000..12f5cea --- /dev/null +++ b/internal/config/common_test.go @@ -0,0 +1,101 @@ +package config + +import "testing" + +func TestVaultConfig_Validate(t *testing.T) { + type fields struct { + VaultToken string + VaultAddr string + SecretId string + RoleId string + TokenIncreaseSeconds int + TokenIncreaseInterval int + PathPrefix string + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + name: "valid config - token", + fields: fields{ + VaultToken: "s.asd83hrfhasfjsda", + VaultAddr: "https://my-vault-instance:443", + TokenIncreaseSeconds: 0, + TokenIncreaseInterval: 0, + PathPrefix: "production", + }, + }, + { + name: "valid config - approle", + fields: fields{ + VaultAddr: "https://my-vault-instance:443", + SecretId: "super-secret", + RoleId: "my-role", + TokenIncreaseSeconds: 0, + TokenIncreaseInterval: 0, + PathPrefix: "dev-v002", + }, + }, + { + name: "invalid config - missing protocol", + fields: fields{ + VaultToken: "s.asd83hrfhasfjsda", + VaultAddr: "my-vault-instance:443", + TokenIncreaseSeconds: 0, + TokenIncreaseInterval: 0, + PathPrefix: "production", + }, + wantErr: true, + }, + { + name: "invalid config - missing port", + fields: fields{ + VaultToken: "s.asd83hrfhasfjsda", + VaultAddr: "http://my-vault-instance", + TokenIncreaseSeconds: 0, + TokenIncreaseInterval: 0, + PathPrefix: "production", + }, + wantErr: true, + }, + { + name: "invalid config - invalid path prefix", + fields: fields{ + VaultToken: "s.asd83hrfhasfjsda", + VaultAddr: "http://my-vault-instance:443", + TokenIncreaseSeconds: 0, + TokenIncreaseInterval: 0, + PathPrefix: "/production", + }, + wantErr: true, + }, + { + name: "invalid config - no auth methods", + fields: fields{ + VaultAddr: "http://my-vault-instance:443", + TokenIncreaseSeconds: 0, + TokenIncreaseInterval: 0, + PathPrefix: "production", + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conf := &VaultConfig{ + VaultToken: tt.fields.VaultToken, + VaultAddr: tt.fields.VaultAddr, + SecretId: tt.fields.SecretId, + RoleId: tt.fields.RoleId, + TokenIncreaseSeconds: tt.fields.TokenIncreaseSeconds, + TokenIncreaseInterval: tt.fields.TokenIncreaseInterval, + PathPrefix: tt.fields.PathPrefix, + } + if err := conf.Validate(); (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}