Skip to content

Commit 30f5259

Browse files
author
lenny
committed
refactor: Rework for addressing PR comments
Signed-off-by: lenny <[email protected]>
1 parent a1bc207 commit 30f5259

File tree

3 files changed

+32
-31
lines changed

3 files changed

+32
-31
lines changed

internal/security/secretstore/init.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func (b *Bootstrap) BootstrapHandler(ctx context.Context, _ *sync.WaitGroup, _ s
135135
shouldContinue = false
136136

137137
case http.StatusTooManyRequests:
138-
lc.Errorf("vault is unsealed and in standby mode (Status Code: %d)", sCode)
138+
// we're done here. Will go into ready mode or reseal
139139
shouldContinue = false
140140

141141
case http.StatusNotImplemented:
@@ -152,19 +152,18 @@ func (b *Bootstrap) BootstrapHandler(ctx context.Context, _ *sync.WaitGroup, _ s
152152
lc.Info("Root token stripped from init response for security reasons")
153153
}
154154

155-
err = client.Unseal(initResponse.Keys, initResponse.KeysBase64)
156-
if err == nil {
155+
err = client.Unseal(initResponse.KeysBase64)
156+
if err != nil {
157157
lc.Errorf("Unable to unseal Vault: %s", err.Error())
158158
return false
159159
}
160160

161161
// We need the unencrypted initResponse in order to generate a temporary root token later
162162
// Make a copy and save the copy, possibly encrypted
163-
var encryptedInitResponse types.InitResponse
163+
encryptedInitResponse := initResponse
164164
// Optionally encrypt the vault init response based on whether encryption was enabled
165165
if vmkEncryption.IsEncrypting() {
166-
encryptedInitResponse, err = vmkEncryption.EncryptInitResponse(initResponse)
167-
if err != nil {
166+
if err := vmkEncryption.EncryptInitResponse(&encryptedInitResponse); err != nil {
168167
lc.Errorf("failed to encrypt init response from secret store: %s", err.Error())
169168
return false
170169
}
@@ -182,13 +181,14 @@ func (b *Bootstrap) BootstrapHandler(ctx context.Context, _ *sync.WaitGroup, _ s
182181
}
183182
// Optionally decrypt the vault init response based on whether encryption was enabled
184183
if vmkEncryption.IsEncrypting() {
185-
initResponse, err = vmkEncryption.DecryptInitResponse(initResponse)
184+
err = vmkEncryption.DecryptInitResponse(&initResponse)
186185
if err != nil {
187186
lc.Errorf("failed to decrypt key shares for secret store unsealing: %s", err.Error())
188187
return false
189188
}
190189
}
191-
err := client.Unseal(initResponse.Keys, initResponse.KeysBase64)
190+
191+
err := client.Unseal(initResponse.KeysBase64)
192192
if err == nil {
193193
shouldContinue = false
194194
}

internal/security/secretstore/vmkencryption.go

+20-19
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,10 @@ func (v *VMKEncryption) IsEncrypting() bool {
104104
// in the end, Keys and KeysBase64 are removed and replaced with
105105
// EncryptedKeys and Nonces in the resulting JSON
106106
// Root token is left untouched
107-
func (v *VMKEncryption) EncryptInitResponse(initResp types.InitResponse) (types.InitResponse, error) {
108-
var encryptedResponse types.InitResponse
107+
func (v *VMKEncryption) EncryptInitResponse(initResp *types.InitResponse) error {
109108
// Check prerequisite (key has been loaded)
110109
if !v.encrypting {
111-
return encryptedResponse, fmt.Errorf("Cannot encrypt init response as key has not been loaded")
110+
return fmt.Errorf("cannot encrypt init response as key has not been loaded")
112111
}
113112

114113
newKeys := make([]string, len(initResp.Keys))
@@ -118,12 +117,12 @@ func (v *VMKEncryption) EncryptInitResponse(initResp types.InitResponse) (types.
118117

119118
plainText, err := hex.DecodeString(hexPlaintext)
120119
if err != nil {
121-
return encryptedResponse, fmt.Errorf("failed to decode hex bytes of keyshare (details omitted): %w", err)
120+
return fmt.Errorf("failed to decode hex bytes of keyshare (details omitted): %w", err)
122121
}
123122

124123
keyShare, nonce, err := v.gcmEncryptKeyShare(plainText, i) // Wrap using a unique AES key
125124
if err != nil {
126-
return encryptedResponse, fmt.Errorf("failed to wrap key %d: %w", i, err)
125+
return fmt.Errorf("failed to wrap key %d: %w", i, err)
127126
}
128127

129128
newKeys[i] = hex.EncodeToString(keyShare)
@@ -133,51 +132,53 @@ func (v *VMKEncryption) EncryptInitResponse(initResp types.InitResponse) (types.
133132
wipeKey(nonce) // Clear out nonce
134133
}
135134

136-
encryptedResponse.EncryptedKeys = newKeys
137-
encryptedResponse.Nonces = newNonces
138-
return encryptedResponse, nil
135+
initResp.EncryptedKeys = newKeys
136+
initResp.Nonces = newNonces
137+
initResp.Keys = nil // strings are immutable, must wait for GC
138+
initResp.KeysBase64 = nil // strings are immutable, must wait for GC
139+
return nil
139140
}
140141

141142
// DecryptInitResponse processes the InitResponse and decrypts the key shares
142143
// in the end, EncryptedKeys and Nonces are removed and replaced with
143144
// Keys and KeysBase64 in the resulting JSON like the init response was originally
144145
// Root token is left untouched
145-
func (v *VMKEncryption) DecryptInitResponse(initResp types.InitResponse) (types.InitResponse, error) {
146-
var decryptedResponse types.InitResponse
147-
146+
func (v *VMKEncryption) DecryptInitResponse(initResp *types.InitResponse) error {
148147
// Check prerequisite (key has been loaded)
149148
if !v.encrypting {
150-
return decryptedResponse, fmt.Errorf("Cannot decrypt init response as key has not been loaded")
149+
return fmt.Errorf("cannot decrypt init response as key has not been loaded")
151150
}
152151

153152
newKeys := make([]string, len(initResp.EncryptedKeys))
154153
newKeysBase64 := make([]string, len(initResp.EncryptedKeys))
155154

156155
for i, hexCiphertext := range initResp.EncryptedKeys {
157-
158156
hexNonce := initResp.Nonces[i]
159157
nonce, err := hex.DecodeString(hexNonce)
160158
if err != nil {
161-
return decryptedResponse, fmt.Errorf("failed to decode hex bytes of nonce: %w", err)
159+
return fmt.Errorf("failed to decode hex bytes of nonce: %w", err)
162160
}
163161

164162
cipherText, err := hex.DecodeString(hexCiphertext)
165163
if err != nil {
166-
return decryptedResponse, fmt.Errorf("failed to decode hex bytes of ciphertext: %w", err)
164+
return fmt.Errorf("failed to decode hex bytes of ciphertext: %w", err)
167165
}
168166

169167
keyShare, err := v.gcmDecryptKeyShare(cipherText, nonce, i) // Unwrap using a unique AES key
170168
if err != nil {
171-
return decryptedResponse, fmt.Errorf("failed to unwrap key %d: %w", i, err)
169+
return fmt.Errorf("failed to unwrap key %d: %w", i, err)
172170
}
173171

174172
newKeys[i] = hex.EncodeToString(keyShare)
175173
newKeysBase64[i] = base64.StdEncoding.EncodeToString(keyShare)
176174
}
177175

178-
decryptedResponse.Keys = newKeys
179-
decryptedResponse.KeysBase64 = newKeysBase64
180-
return decryptedResponse, nil
176+
initResp.Keys = newKeys
177+
initResp.KeysBase64 = newKeysBase64
178+
initResp.EncryptedKeys = nil
179+
initResp.Nonces = nil
180+
181+
return nil
181182
}
182183

183184
//

internal/security/secretstore/vmkencryption_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ func TestVMKEncryption(t *testing.T) {
5959
err := vmkEncryption.LoadIKM("/bin/myikm")
6060
require.NoError(t, err)
6161

62-
initResp, err = vmkEncryption.EncryptInitResponse(initResp)
62+
err = vmkEncryption.EncryptInitResponse(&initResp)
6363
require.NoError(t, err)
6464

65-
initResp, err = vmkEncryption.DecryptInitResponse(initResp)
65+
err = vmkEncryption.DecryptInitResponse(&initResp)
6666
require.NoError(t, err)
6767
require.Equal(t, initialInitResp, initResp)
6868

@@ -91,10 +91,10 @@ func TestVMKEncryptionFailPath(t *testing.T) {
9191
err := vmkEncryption.LoadIKM("/bin/myikm")
9292
require.Error(t, err)
9393

94-
initResp, err = vmkEncryption.EncryptInitResponse(initResp)
94+
err = vmkEncryption.EncryptInitResponse(&initResp)
9595
require.Error(t, err)
9696

97-
initResp, err = vmkEncryption.DecryptInitResponse(initResp)
97+
err = vmkEncryption.DecryptInitResponse(&initResp)
9898
require.Error(t, err)
9999

100100
vmkEncryption.WipeIKM()

0 commit comments

Comments
 (0)