diff --git a/cmd/client/client.go b/cmd/client/client.go index 5f5ffe8..ad94131 100644 --- a/cmd/client/client.go +++ b/cmd/client/client.go @@ -9,12 +9,8 @@ import ( "github.com/soerenschneider/acmevault/internal/client/hooks" "github.com/soerenschneider/acmevault/internal/config" "github.com/soerenschneider/acmevault/pkg/certstorage/vault" - "strings" ) -// Prefix of the configured AppRole role_ids for this tool -const roleIdPrefix = "acme-client-" - func main() { configPath := cmd.ParseCliFlags() log.Info().Msgf("acmevault-client version %s, commit %s", internal.BuildVersion, internal.CommitHash) @@ -65,8 +61,7 @@ func pickUpCerts(client *client.VaultAcmeClient, conf config.AcmeVaultClientConf return errors.New("empty client passed") } - domain := strings.ReplaceAll(conf.RoleId, roleIdPrefix, "") - return client.RetrieveAndSave(domain) + return client.RetrieveAndSave(conf.Domain) } func writeMetrics(conf config.AcmeVaultClientConfig) { diff --git a/internal/config/client.go b/internal/config/client.go index 86cd76d..3a5956a 100644 --- a/internal/config/client.go +++ b/internal/config/client.go @@ -2,6 +2,7 @@ package config import ( "encoding/json" + "errors" "fmt" "github.com/rs/zerolog/log" "io/ioutil" @@ -10,6 +11,7 @@ import ( type AcmeVaultClientConfig struct { VaultConfig FsWriterConfig + Domain string `json:"domain"` Hook []string `json:"hooks"` MetricsPath string `json:"metricsPath"` } @@ -27,6 +29,10 @@ func AcmeVaultClientConfigFromFile(path string) (AcmeVaultClientConfig, error) { } func (conf AcmeVaultClientConfig) Validate() error { + if len(conf.Domain) == 0 { + return errors.New("Missing field `domain`") + } + err := conf.FsWriterConfig.Validate() if err != nil { return err @@ -39,6 +45,7 @@ func (conf AcmeVaultClientConfig) Print() { log.Info().Msg("--- Client Config Start ---") conf.VaultConfig.Print() conf.FsWriterConfig.Print() + log.Info().Msgf("Domain=%s", conf.Domain) if len(conf.Hook) > 0 { log.Info().Msgf("Hooks=%v", conf.Hook) } diff --git a/internal/config/common.go b/internal/config/common.go index 8876003..90a14cd 100644 --- a/internal/config/common.go +++ b/internal/config/common.go @@ -20,6 +20,7 @@ type VaultConfig struct { TokenIncreaseSeconds int `json:"tokenIncreaseSeconds"` TokenIncreaseInterval int `json:"tokenIncreaseInterval"` PathPrefix string `json:"vaultPathPrefix"` + DomainPathFormat string `json:"domainPathFormat"` } func (conf *VaultConfig) IsTokenIncreaseEnabled() bool { @@ -53,6 +54,9 @@ func (conf *VaultConfig) Print() { if conf.TokenIncreaseInterval > 0 { log.Info().Msgf("TokenIncreaseInterval=%d", conf.TokenIncreaseInterval) } + if len(conf.DomainPathFormat) > 0 { + log.Info().Msgf("DomainPathFormat=%s", conf.DomainPathFormat) + } } func DefaultVaultConfig() VaultConfig { @@ -104,6 +108,12 @@ func (conf *VaultConfig) Validate() error { return errors.New("specified secretIdFile is not writable, quitting") } + if len(conf.DomainPathFormat) > 0 { + if !strings.ContainsRune(conf.DomainPathFormat, '%') { + return fmt.Errorf("the domainPathFormat '%s' does not seem to be a valid format string", conf.DomainPathFormat) + } + } + return nil } diff --git a/pkg/certstorage/vault/vault.go b/pkg/certstorage/vault/vault.go index 6eef8e4..2f9da5d 100644 --- a/pkg/certstorage/vault/vault.go +++ b/pkg/certstorage/vault/vault.go @@ -493,6 +493,13 @@ func (vault *VaultBackend) unwrapAndSaveSecretId(wrappingToken, secretIdFile str return parsed, nil } +func (vault *VaultBackend) formatDomain(domain string) string { + if len(vault.conf.DomainPathFormat) == 0 { + return domain + } + return fmt.Sprintf(vault.conf.DomainPathFormat, domain) +} + func (vault *VaultBackend) getAwsCredentialsPath() string { return fmt.Sprintf("/aws/creds/%s", awsRole) } @@ -502,9 +509,11 @@ func (vault *VaultBackend) getAccountPath(hash string) string { } func (vault *VaultBackend) getCertDataPath(domain string) string { - return fmt.Sprintf("%s/client/%s/certificate", vault.namespacedPrefix, domain) + domainFormatted := vault.formatDomain(domain) + return fmt.Sprintf("%s/client/%s/certificate", vault.namespacedPrefix, domainFormatted) } func (vault *VaultBackend) getSecretDataPath(domain string) string { - return fmt.Sprintf("%s/client/%s/privatekey", vault.namespacedPrefix, domain) + domainFormatted := vault.formatDomain(domain) + return fmt.Sprintf("%s/client/%s/privatekey", vault.namespacedPrefix, domainFormatted) } diff --git a/pkg/certstorage/vault/vault_test.go b/pkg/certstorage/vault/vault_test.go index b137ffd..9e0c8ea 100644 --- a/pkg/certstorage/vault/vault_test.go +++ b/pkg/certstorage/vault/vault_test.go @@ -1,6 +1,8 @@ package vault import ( + "github.com/hashicorp/vault/api" + "github.com/soerenschneider/acmevault/internal/config" "reflect" "testing" ) @@ -45,3 +47,56 @@ func Test_buildSecretPayload(t *testing.T) { }) } } + +func TestVaultBackend_getSecretDataPath(t *testing.T) { + type fields struct { + client *api.Client + conf config.VaultConfig + revokeToken bool + namespacedPrefix string + } + tests := []struct { + name string + fields fields + args string + want string + }{ + { + name: "custom domain path", + fields: fields{ + client: nil, + conf: config.VaultConfig{ + DomainPathFormat: "machine-%s", + }, + revokeToken: false, + namespacedPrefix: "acmevault", + }, + args: "test.domain.tld", + want: "acmevault/client/machine-test.domain.tld/privatekey", + }, + { + name: "no domain format given", + fields: fields{ + client: nil, + conf: config.VaultConfig{}, + revokeToken: false, + namespacedPrefix: "acmevault", + }, + args: "test.domain.tld", + want: "acmevault/client/test.domain.tld/privatekey", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vault := &VaultBackend{ + client: tt.fields.client, + conf: tt.fields.conf, + revokeToken: tt.fields.revokeToken, + namespacedPrefix: tt.fields.namespacedPrefix, + } + if got := vault.getSecretDataPath(tt.args); got != tt.want { + t.Errorf("getSecretDataPath() = %v, want %v", got, tt.want) + } + }) + } +}