From 2c2cbf1ae19a209da39a55d7404990821b68c765 Mon Sep 17 00:00:00 2001 From: Codey Whitt Date: Fri, 6 Mar 2020 18:42:46 -0700 Subject: [PATCH] decrypt master key in parallel Currently, Sops tries to decrypt with each master key in series. In practice, this can lead to significant performance issues if the working key is at the end of the list. The example case we encountered was using Sops with Ansible and the upcoming Sops plugin for it. During execution, Ansible calls Sops for every task to decrypt the variables. In our case, we used both KMS and PGP keys, and on a machine that relied on PGP it would take over 30 seconds to decrypt a file as Sops iterated through the providers. This made PGP-based builds take hours where KMS-based builds took minutes, because the KMS key lookup would fail thousands of times throughout the run. This change makes the decryption attempts happen in parallel and returns the first one to succeed, making PGP-based builds faster than KMS-based builds without affecting KMS-based build times. 
--- go.mod | 1 + kms/keysource.go | 19 +++++- sops.go | 161 +++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 159 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index ed9685ab5..6e9b73f66 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/blang/semver v3.5.1+incompatible github.com/fatih/color v1.7.0 github.com/golang/protobuf v1.3.2 + github.com/google/go-cmp v0.3.0 github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf github.com/goware/prefixer v0.0.0-20160118172347-395022866408 github.com/hashicorp/vault/api v1.0.4 diff --git a/kms/keysource.go b/kms/keysource.go index 21eba8783..37277274d 100644 --- a/kms/keysource.go +++ b/kms/keysource.go @@ -10,6 +10,7 @@ import ( "os" "regexp" "strings" + "sync" "time" "go.mozilla.org/sops/v3/logging" @@ -30,6 +31,9 @@ func init() { log = logging.NewLogger("AWSKMS") } +// Need a mutex to sync access to the service +var kmsSvcMtx sync.RWMutex + // this needs to be a global var for unit tests to work (mockKMS redefines // it in keysource_test.go) var kmsSvc kmsiface.KMSAPI @@ -92,17 +96,26 @@ func (key *MasterKey) Decrypt() ([]byte, error) { log.WithField("arn", key.Arn).Info("Decryption failed") return nil, fmt.Errorf("Error base64-decoding encrypted data key: %s", err) } + + // Capture service locally to prevent data races + kmsSvcMtx.RLock() + svc := kmsSvc + kmsSvcMtx.RUnlock() + // isMocked is set by unit test to indicate that the KMS service // has already been initialized. it's ugly, but it works. 
- if kmsSvc == nil || !isMocked { + if svc == nil || !isMocked { sess, err := key.createSession() if err != nil { log.WithField("arn", key.Arn).Info("Decryption failed") return nil, fmt.Errorf("Error creating AWS session: %v", err) } - kmsSvc = kms.New(sess) + svc = kms.New(sess) + kmsSvcMtx.Lock() + kmsSvc = svc + kmsSvcMtx.Unlock() } - decrypted, err := kmsSvc.Decrypt(&kms.DecryptInput{CiphertextBlob: k, EncryptionContext: key.EncryptionContext}) + decrypted, err := svc.Decrypt(&kms.DecryptInput{CiphertextBlob: k, EncryptionContext: key.EncryptionContext}) if err != nil { log.WithField("arn", key.Arn).Info("Decryption failed") return nil, fmt.Errorf("Error decrypting key: %v", err) diff --git a/sops.go b/sops.go index d1d64ec8e..f10cc6dfd 100644 --- a/sops.go +++ b/sops.go @@ -44,6 +44,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/sirupsen/logrus" @@ -642,48 +643,170 @@ func (m Metadata) GetDataKeyWithKeyServices(svcs []keyservice.KeyServiceClient) // any of the MasterKeys in the KeyGroup with any of the provided key services, // returning as soon as one key service succeeds. func decryptKeyGroup(group KeyGroup, svcs []keyservice.KeyServiceClient) ([]byte, error) { - var keyErrs []error + var ( + part []byte + wg sync.WaitGroup + wgCount int + keyErrs []error + ) + ctx, cancel := context.WithCancel(context.Background()) + keyChan := make(chan []byte) + errChan := make(chan error) + for _, key := range group { - part, err := decryptKey(key, svcs) - if err != nil { - keyErrs = append(keyErrs, err) - } else { - return part, nil + // Run decryption inside of goroutine + wg.Add(1) + wgCount++ + go func(ctx context.Context, wg *sync.WaitGroup, key keys.MasterKey, keyChan chan []byte, errChan chan error) { + var ( + err error + rsp []byte + ) + + // Decrypt the key + rsp, err = decryptKey(ctx, key, svcs) + + select { + // Key was already decrypted. 
Time to bail + case <-ctx.Done(): + return + + // Forward the response + default: + if err != nil { + errChan <- err + } else { + keyChan <- rsp + } + } + }(ctx, &wg, key, keyChan, errChan) + } + + // Setup goroutine to watch for the decryption responses + go func(ctx context.Context, wg *sync.WaitGroup, keyChan chan []byte, cancel context.CancelFunc) { + for { + select { + // Bail out if the context is canceled + case <-ctx.Done(): + for i := 0; i < int(wgCount); i++ { + wg.Done() + } + return + + // Receive the key + case key := <-keyChan: + if part == nil { + part = key + } + cancel() + + // Receive the errors + case err := <-errChan: + keyErrs = append(keyErrs, err) + } + wg.Done() + wgCount-- } + }(ctx, &wg, keyChan, cancel) + + // Wait for the services to return a key + wg.Wait() + + // Return the key + if part != nil { + return part, nil } + + // Return the errors return nil, decryptKeyErrors(keyErrs) } // decryptKey tries to decrypt the contents of the provided MasterKey with any // of the key services, returning as soon as one key service succeeds. 
-func decryptKey(key keys.MasterKey, svcs []keyservice.KeyServiceClient) ([]byte, error) { +func decryptKey(ctx context.Context, key keys.MasterKey, svcs []keyservice.KeyServiceClient) ([]byte, error) { svcKey := keyservice.KeyFromMasterKey(key) var part []byte + var wgCount int decryptErr := decryptKeyError{ keyName: key.ToString(), } + + // Use channels to extract the key in parallel from the services + var wg sync.WaitGroup + localCtx, cancel := context.WithCancel(ctx) + keyChan := make(chan []byte) + errChan := make(chan error) + for _, svc := range svcs { - // All keys in a key group encrypt the same part, so as soon - // as we decrypt it successfully with one key, we need to - // proceed with the next group - var err error - if part == nil { - var rsp *keyservice.DecryptResponse + // Run decryption inside of goroutine + wg.Add(1) + wgCount++ + go func(ctx context.Context, wg *sync.WaitGroup, svc keyservice.KeyServiceClient, keyChan chan []byte, errChan chan error) { + var ( + err error + rsp *keyservice.DecryptResponse + ) + + // Decrypt the key rsp, err = svc.Decrypt( - context.Background(), + ctx, &keyservice.DecryptRequest{ Ciphertext: key.EncryptedDataKey(), Key: &svcKey, - }) - if err == nil { - part = rsp.Plaintext + }, + ) + + select { + // Key was already decrypted. 
Time to bail + case <-ctx.Done(): + return + // Forward the response + default: + if err != nil { + errChan <- err + } else { + keyChan <- rsp.Plaintext + } } - } - decryptErr.errs = append(decryptErr.errs, err) + }(localCtx, &wg, svc, keyChan, errChan) } + + // Setup goroutine to watch for the key + go func(ctx context.Context, wg *sync.WaitGroup, keyChan chan []byte, cancel context.CancelFunc) { + for { + select { + // Bail out if the context is canceled + case <-ctx.Done(): + for i := 0; i < int(wgCount); i++ { + wg.Done() + } + return + + // Receive the key + case key := <-keyChan: + if part == nil { + part = key + } + cancel() + + // Receive the errors + case err := <-errChan: + decryptErr.errs = append(decryptErr.errs, err) + } + wg.Done() + wgCount-- + } + }(ctx, &wg, keyChan, cancel) + + // Wait for the services to return + wg.Wait() + + // Return the key if part != nil { return part, nil } + + // Return the errors return nil, &decryptErr }