Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

decrypt master key in parallel #638

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/blang/semver v3.5.1+incompatible
github.com/fatih/color v1.7.0
github.com/golang/protobuf v1.3.2
github.com/google/go-cmp v0.3.0
github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf
github.com/goware/prefixer v0.0.0-20160118172347-395022866408
github.com/hashicorp/vault/api v1.0.4
Expand Down
19 changes: 16 additions & 3 deletions kms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"regexp"
"strings"
"sync"
"time"

"go.mozilla.org/sops/v3/logging"
Expand All @@ -30,6 +31,9 @@ func init() {
log = logging.NewLogger("AWSKMS")
}

// Need a mutex to sync access to the service
var kmsSvcMtx sync.RWMutex

// this needs to be a global var for unit tests to work (mockKMS redefines
// it in keysource_test.go)
var kmsSvc kmsiface.KMSAPI
Expand Down Expand Up @@ -92,17 +96,26 @@ func (key *MasterKey) Decrypt() ([]byte, error) {
log.WithField("arn", key.Arn).Info("Decryption failed")
return nil, fmt.Errorf("Error base64-decoding encrypted data key: %s", err)
}

// Capture service locally to prevent data races
kmsSvcMtx.RLock()
svc := kmsSvc
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes were necessary to deal with the concurrent access of the kmsSvc by the decrypt calls.

kmsSvcMtx.RUnlock()

// isMocked is set by unit test to indicate that the KMS service
// has already been initialized. it's ugly, but it works.
if kmsSvc == nil || !isMocked {
if svc == nil || !isMocked {
sess, err := key.createSession()
if err != nil {
log.WithField("arn", key.Arn).Info("Decryption failed")
return nil, fmt.Errorf("Error creating AWS session: %v", err)
}
kmsSvc = kms.New(sess)
svc = kms.New(sess)
kmsSvcMtx.Lock()
kmsSvc = svc
kmsSvcMtx.Unlock()
}
decrypted, err := kmsSvc.Decrypt(&kms.DecryptInput{CiphertextBlob: k, EncryptionContext: key.EncryptionContext})
decrypted, err := svc.Decrypt(&kms.DecryptInput{CiphertextBlob: k, EncryptionContext: key.EncryptionContext})
if err != nil {
log.WithField("arn", key.Arn).Info("Decryption failed")
return nil, fmt.Errorf("Error decrypting key: %v", err)
Expand Down
161 changes: 142 additions & 19 deletions sops.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"regexp"
"strconv"
"strings"
"sync"
"time"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -642,48 +643,170 @@ func (m Metadata) GetDataKeyWithKeyServices(svcs []keyservice.KeyServiceClient)
// any of the MasterKeys in the KeyGroup with any of the provided key services,
// returning as soon as one key service succeeds.
func decryptKeyGroup(group KeyGroup, svcs []keyservice.KeyServiceClient) ([]byte, error) {
var keyErrs []error
var (
part []byte
wg sync.WaitGroup
wgCount int
keyErrs []error
)
ctx, cancel := context.WithCancel(context.Background())
keyChan := make(chan []byte)
errChan := make(chan error)

for _, key := range group {
part, err := decryptKey(key, svcs)
if err != nil {
keyErrs = append(keyErrs, err)
} else {
return part, nil
// Run decryption inside of goroutine
wg.Add(1)
wgCount++
go func(ctx context.Context, wg *sync.WaitGroup, key keys.MasterKey, keyChan chan []byte, errChan chan error) {
var (
err error
rsp []byte
)

// Decrypt the key
rsp, err = decryptKey(ctx, key, svcs)

select {
// Key was already decrypted. Time to bail
case <-ctx.Done():
return

// Forward the response
default:
if err != nil {
errChan <- err
} else {
keyChan <- rsp
}
}
}(ctx, &wg, key, keyChan, errChan)
}

// Setup goroutine to watch for the decryption responses
go func(ctx context.Context, wg *sync.WaitGroup, keyChan chan []byte, cancel context.CancelFunc) {
for {
select {
// Bail out if the context is canceled
case <-ctx.Done():
for i := 0; i < int(wgCount); i++ {
wg.Done()
}
return

// Receive the key
case key := <-keyChan:
if part == nil {
part = key
}
cancel()

// Receive the errors
case err := <-errChan:
keyErrs = append(keyErrs, err)
}
wg.Done()
wgCount--
}
}(ctx, &wg, keyChan, cancel)

// Wait for the services to return a key
wg.Wait()

// Return the key
if part != nil {
return part, nil
}

// Return the errors
return nil, decryptKeyErrors(keyErrs)
}

// decryptKey tries to decrypt the contents of the provided MasterKey with any
// of the key services, returning as soon as one key service succeeds.
func decryptKey(key keys.MasterKey, svcs []keyservice.KeyServiceClient) ([]byte, error) {
func decryptKey(ctx context.Context, key keys.MasterKey, svcs []keyservice.KeyServiceClient) ([]byte, error) {
svcKey := keyservice.KeyFromMasterKey(key)
var part []byte
var wgCount int
decryptErr := decryptKeyError{
keyName: key.ToString(),
}

// Use channels to extract the key in parallel from the services
var wg sync.WaitGroup
localCtx, cancel := context.WithCancel(ctx)
keyChan := make(chan []byte)
errChan := make(chan error)

for _, svc := range svcs {
// All keys in a key group encrypt the same part, so as soon
// as we decrypt it successfully with one key, we need to
// proceed with the next group
var err error
if part == nil {
var rsp *keyservice.DecryptResponse
// Run decryption inside of goroutine
wg.Add(1)
wgCount++
go func(ctx context.Context, wg *sync.WaitGroup, svc keyservice.KeyServiceClient, keyChan chan []byte, errChan chan error) {
var (
err error
rsp *keyservice.DecryptResponse
)

// Decrypt the key
rsp, err = svc.Decrypt(
context.Background(),
ctx,
&keyservice.DecryptRequest{
Ciphertext: key.EncryptedDataKey(),
Key: &svcKey,
})
if err == nil {
part = rsp.Plaintext
},
)

select {
// Key was already decrypted. Time to bail
case <-ctx.Done():
return
// Forward the response
default:
if err != nil {
errChan <- err
} else {
keyChan <- rsp.Plaintext
}
}
}
decryptErr.errs = append(decryptErr.errs, err)
}(localCtx, &wg, svc, keyChan, errChan)
}

// Setup goroutine to watch for the key
go func(ctx context.Context, wg *sync.WaitGroup, keyChan chan []byte, cancel context.CancelFunc) {
for {
select {
// Bail out if the context is canceled
case <-ctx.Done():
for i := 0; i < int(wgCount); i++ {
wg.Done()
}
return

// Receive the key
case key := <-keyChan:
if part == nil {
part = key
}
cancel()

// Receive the errors
case err := <-errChan:
decryptErr.errs = append(decryptErr.errs, err)
}
wg.Done()
wgCount--
}
}(ctx, &wg, keyChan, cancel)

// Wait for the services to return
wg.Wait()

// Return the key
if part != nil {
return part, nil
}

// Return the errors
return nil, &decryptErr
}

Expand Down