Skip to content

Commit

Permalink
fix: Factored out main logic, added secret struct, fixed affected parts
Browse files Browse the repository at this point in the history
Signed-off-by: Bence Csati <[email protected]>
  • Loading branch information
csatib02 committed Nov 28, 2023
1 parent 29cad09 commit 7256648
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 117 deletions.
83 changes: 83 additions & 0 deletions env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright © 2023 Bank-Vaults Maintainers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"fmt"
"os"
"strings"

"github.com/bank-vaults/secret-init/provider"
"github.com/bank-vaults/secret-init/provider/file"
)

func GetEnvironMap() map[string]string {
environ := make(map[string]string, len(os.Environ()))
for _, env := range os.Environ() {
split := strings.SplitN(env, "=", 2)
name := split[0]
value := split[1]
environ[name] = value
}

return environ
}

func ExtractPathsFromEnvs(envs map[string]string) []string {
var secretPaths []string

for _, path := range envs {
if p, path := getProviderPath(path); p != nil {
secretPaths = append(secretPaths, path)
}
}

return secretPaths
}

func CreateSecretEnvsFrom(envs map[string]string, secrets []provider.Secret) ([]string, error) {
// Reverse the map so we can match
// the environment variable key to the secret
// by using the secret path
reversedEnvs := make(map[string]string)
for envKey, path := range envs {
if p, path := getProviderPath(path); p != nil {
reversedEnvs[path] = envKey
}
}

var secretsEnv []string
for _, secret := range secrets {
path := secret.Path
value := secret.Value
key, ok := reversedEnvs[path]
if !ok {
return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", path)
}
secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", key, value))
}

return secretsEnv, nil
}

// Returns the detected provider name and path with removed prefix
func getProviderPath(path string) (*string, string) {
if strings.HasPrefix(path, "file:") {
var fileProviderName = file.ProviderName
return &fileProviderName, strings.TrimPrefix(path, "file:")
}

return nil, path
}
61 changes: 3 additions & 58 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"os/exec"
"os/signal"
"slices"
"strings"
"syscall"
"time"

Expand All @@ -39,7 +38,7 @@ import (
func NewProvider(providerName string) (provider.Provider, error) {
switch providerName {
case file.ProviderName:
provider, err := file.NewProvider(os.DirFS("/secrets"))
provider, err := file.NewProvider(os.DirFS("/"))
if err != nil {
return nil, err
}
Expand All @@ -51,59 +50,6 @@ func NewProvider(providerName string) (provider.Provider, error) {
}
}

func CreateMapOfEnvs() map[string]string {
environ := make(map[string]string, len(os.Environ()))
for _, env := range os.Environ() {
split := strings.SplitN(env, "=", 2)
name := split[0]
value := split[1]
environ[name] = value
}

return environ
}

func ExtractPathsFromEnvs(envs map[string]string) []string {
var secretPaths []string

for _, path := range envs {
if strings.HasPrefix(path, "file:") {
path = strings.TrimPrefix(path, "file://")
secretPaths = append(secretPaths, path)
}
}

return secretPaths
}

func CreateEnvsFromLoadedSecrets(envs map[string]string, secrets []string) ([]string, error) {
// Reverse the map so we can match
// the environment variable key to the secret
// by using the secret path
reversedEnvs := make(map[string]string)
for envKey, path := range envs {
if strings.HasPrefix(path, "file:") {
path = strings.TrimPrefix(path, "file://")
reversedEnvs[path] = envKey
}
}

var secretsEnv []string
for _, secret := range secrets {
split := strings.SplitN(secret, "|", 2)
secretPath := split[0]

secretValue := split[1]
secretKey, ok := reversedEnvs[secretPath]
if !ok {
return nil, fmt.Errorf("failed to find environment variable key for secret path: %s", secretPath)
}
secretsEnv = append(secretsEnv, fmt.Sprintf("%s=%s", secretKey, secretValue))
}

return secretsEnv, nil
}

func main() {
var logger *slog.Logger
{
Expand Down Expand Up @@ -189,7 +135,7 @@ func main() {
os.Exit(1)
}

environ := CreateMapOfEnvs()
environ := GetEnvironMap()
paths := ExtractPathsFromEnvs(environ)

ctx := context.Background()
Expand All @@ -199,8 +145,7 @@ func main() {

os.Exit(1)
}

secretsEnv, err := CreateEnvsFromLoadedSecrets(environ, secrets)
secretsEnv, err := CreateSecretEnvsFrom(environ, secrets)
if err != nil {
logger.Error(fmt.Errorf("failed to create environment variables from loaded secrets: %w", err).Error())

Expand Down
58 changes: 14 additions & 44 deletions provider/file/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"io/fs"
"strings"

"github.com/bank-vaults/secret-init/provider"
)
Expand All @@ -33,64 +34,33 @@ func NewProvider(fs fs.FS) (provider.Provider, error) {
return nil, fmt.Errorf("file system is nil")
}

isEmpty, err := isFileSystemEmpty(fs)
if err != nil {
return nil, fmt.Errorf("failed to check if file system is empty: %w", err)
}
if isEmpty {
return nil, fmt.Errorf("file system is empty")
}

return &Provider{fs: fs}, nil
}

func (provider *Provider) LoadSecrets(_ context.Context, paths []string) ([]string, error) {
var secrets []string
func (p *Provider) LoadSecrets(_ context.Context, paths []string) ([]provider.Secret, error) {
var secrets []provider.Secret

for i, path := range paths {
secret, err := provider.getSecretFromFile(path)
for _, path := range paths {
secret, err := p.getSecretFromFile(path)
if err != nil {
return nil, fmt.Errorf("failed to get secret from file: %w", err)
}
// Add the secret path with a "|" separator character
// to the secrets slice along with the secret
// so later we can match it to the environment key
secrets = append(secrets, paths[i]+"|"+secret)
}

return secrets, nil
}

func isFileSystemEmpty(fsys fs.FS) (bool, error) {
dir, err := fs.ReadDir(fsys, ".")
fmt.Println(dir, err)
if err != nil {
return false, err
}

for _, entry := range dir {
if entry.IsDir() || entry.Type().IsRegular() {
return false, nil
}
secrets = append(secrets, provider.Secret{
Path: path,
Value: secret,
})
}

return true, nil
return secrets, nil
}

func (provider *Provider) getSecretFromFile(path string) (string, error) {
content, err := provider.readFile(path)
func (p *Provider) getSecretFromFile(filepath string) (string, error) {
filepath = strings.TrimLeft(filepath, "/")
content, err := fs.ReadFile(p.fs, filepath)
if err != nil {
return "", err
return "", fmt.Errorf("failed to read file: %w", err)
}

return string(content), nil
}

func (provider *Provider) readFile(path string) ([]byte, error) {
content, err := fs.ReadFile(provider.fs, path)
if err != nil {
return nil, fmt.Errorf("failed to read file: %w", err)
}

return content, nil
}
37 changes: 23 additions & 14 deletions provider/file/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"testing/fstest"

"github.com/stretchr/testify/assert"

"github.com/bank-vaults/secret-init/provider"
)

func TestNewProvider(t *testing.T) {
Expand All @@ -35,6 +37,7 @@ func TestNewProvider(t *testing.T) {
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
wantErr: false,
wantType: true,
Expand All @@ -45,18 +48,11 @@ func TestNewProvider(t *testing.T) {
wantErr: true,
wantType: false,
},
{
name: "Empty file system",
fs: fstest.MapFS{},
wantErr: true,
wantType: false,
},
}

for _, tt := range tests {
ttp := tt
t.Run(ttp.name, func(t *testing.T) {

prov, err := NewProvider(ttp.fs)
if (err != nil) != ttp.wantErr {
t.Fatalf("NewProvider() error = %v, wantErr %v", err, ttp.wantErr)
Expand All @@ -77,25 +73,39 @@ func TestLoadSecrets(t *testing.T) {
fs fs.FS
paths []string
wantErr bool
wantData []string
wantData []provider.Secret
}{
{
name: "Load secrets successfully",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
paths: []string{
"test/secrets/sqlpass.txt",
"test/secrets/awsaccess.txt",
"test/secrets/awsid.txt",
},
wantErr: false,
wantData: []provider.Secret{
{Path: "test/secrets/sqlpass.txt", Value: "3xtr3ms3cr3t"},
{Path: "test/secrets/awsaccess.txt", Value: "s3cr3t"},
{Path: "test/secrets/awsid.txt", Value: "secretId"},
},
paths: []string{"test/secrets/sqlpass.txt", "test/secrets/awsaccess.txt"},
wantErr: false,
wantData: []string{"test/secrets/sqlpass.txt|3xtr3ms3cr3t", "test/secrets/awsaccess.txt|s3cr3t"},
},
{
name: "Fail to load secrets due to invalid path",
fs: fstest.MapFS{
"test/secrets/sqlpass.txt": &fstest.MapFile{Data: []byte("3xtr3ms3cr3t")},
"test/secrets/awsaccess.txt": &fstest.MapFile{Data: []byte("s3cr3t")},
"test/secrets/awsid.txt": &fstest.MapFile{Data: []byte("secretId")},
},
paths: []string{
"test/secrets/mistake/sqlpass.txt",
"test/secrets/mistake/awsaccess.txt",
"test/secrets/mistake/awsid.txt",
},
paths: []string{"test/secrets/mistake/sqlpass.txt", "test/secrets/mistake/awsaccess.txt"},
wantErr: true,
wantData: nil,
},
Expand All @@ -108,8 +118,7 @@ func TestLoadSecrets(t *testing.T) {
if assert.NoError(t, err, "Unexpected error") {
secrets, err := provider.LoadSecrets(context.Background(), ttp.paths)
assert.Equal(t, ttp.wantErr, err != nil, "Unexpected error status")

assert.Equal(t, ttp.wantData, secrets, "Unexpected secrets loaded")
assert.ElementsMatch(t, ttp.wantData, secrets, "Unexpected secrets loaded")
}
})
}
Expand Down
8 changes: 7 additions & 1 deletion provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@ import "context"

// Provider is an interface for securely loading secrets based on environment variables.
type Provider interface {
LoadSecrets(ctx context.Context, paths []string) ([]string, error)
LoadSecrets(ctx context.Context, paths []string) ([]Secret, error)
}

// Secret holds Provider-specific secret data.
type Secret struct {
Path string
Value string
}

0 comments on commit 7256648

Please sign in to comment.