diff --git a/go.mod b/go.mod index ed9685ab5..6e9b73f66 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/blang/semver v3.5.1+incompatible github.com/fatih/color v1.7.0 github.com/golang/protobuf v1.3.2 + github.com/google/go-cmp v0.3.0 github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf github.com/goware/prefixer v0.0.0-20160118172347-395022866408 github.com/hashicorp/vault/api v1.0.4 diff --git a/kms/keysource.go b/kms/keysource.go index 21eba8783..37277274d 100644 --- a/kms/keysource.go +++ b/kms/keysource.go @@ -10,6 +10,7 @@ import ( "os" "regexp" "strings" + "sync" "time" "go.mozilla.org/sops/v3/logging" @@ -30,6 +31,9 @@ func init() { log = logging.NewLogger("AWSKMS") } +// Need a mutex to sync access to the service +var kmsSvcMtx sync.RWMutex + // this needs to be a global var for unit tests to work (mockKMS redefines // it in keysource_test.go) var kmsSvc kmsiface.KMSAPI @@ -92,17 +96,26 @@ func (key *MasterKey) Decrypt() ([]byte, error) { log.WithField("arn", key.Arn).Info("Decryption failed") return nil, fmt.Errorf("Error base64-decoding encrypted data key: %s", err) } + + // Capture service locally to prevent data races + kmsSvcMtx.RLock() + svc := kmsSvc + kmsSvcMtx.RUnlock() + // isMocked is set by unit test to indicate that the KMS service // has already been initialized. it's ugly, but it works. - if kmsSvc == nil || !isMocked { + if svc == nil || !isMocked { sess, err := key.createSession() if err != nil { log.WithField("arn", key.Arn).Info("Decryption failed") return nil, fmt.Errorf("Error creating AWS session: %v", err) } - kmsSvc = kms.New(sess) + svc = kms.New(sess) + kmsSvcMtx.Lock() + kmsSvc = svc + kmsSvcMtx.Unlock() } - decrypted, err := kmsSvc.Decrypt(&kms.DecryptInput{CiphertextBlob: k, EncryptionContext: key.EncryptionContext}) + decrypted, err := svc.Decrypt(&kms.DecryptInput{CiphertextBlob: k, EncryptionContext: key.EncryptionContext}) if err != nil { log.WithField("arn", key.Arn).Info("Decryption failed") return nil, fmt.Errorf("Error decrypting key: %v", err) diff --git a/sops.go b/sops.go index d1d64ec8e..f10cc6dfd 100644 --- a/sops.go +++ b/sops.go @@ -44,6 +44,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/sirupsen/logrus" @@ -642,48 +643,170 @@ func (m Metadata) GetDataKeyWithKeyServices(svcs []keyservice.KeyServiceClient) // any of the MasterKeys in the KeyGroup with any of the provided key services, // returning as soon as one key service succeeds. func decryptKeyGroup(group KeyGroup, svcs []keyservice.KeyServiceClient) ([]byte, error) { - var keyErrs []error + var ( + part []byte + wg sync.WaitGroup + wgCount int + keyErrs []error + ) + ctx, cancel := context.WithCancel(context.Background()) + keyChan := make(chan []byte) + errChan := make(chan error) + for _, key := range group { - part, err := decryptKey(key, svcs) - if err != nil { - keyErrs = append(keyErrs, err) - } else { - return part, nil + // Run decryption inside of goroutine + wg.Add(1) + wgCount++ + go func(ctx context.Context, wg *sync.WaitGroup, key keys.MasterKey, keyChan chan []byte, errChan chan error) { + var ( + err error + rsp []byte + ) + + // Decrypt the key + rsp, err = decryptKey(ctx, key, svcs) + + select { + // Key was already decrypted. Time to bail + case <-ctx.Done(): + return + + // Forward the response + default: + if err != nil { + errChan <- err + } else { + keyChan <- rsp + } + } + }(ctx, &wg, key, keyChan, errChan) + } + + // Setup goroutine to watch for the decryption responses + go func(ctx context.Context, wg *sync.WaitGroup, keyChan chan []byte, cancel context.CancelFunc) { + for { + select { + // Bail out if the context is canceled + case <-ctx.Done(): + for i := 0; i < int(wgCount); i++ { + wg.Done() + } + return + + // Receive the key + case key := <-keyChan: + if part == nil { + part = key + } + cancel() + + // Receive the errors + case err := <-errChan: + keyErrs = append(keyErrs, err) + } + wg.Done() + wgCount-- } + }(ctx, &wg, keyChan, cancel) + + // Wait for the services to return a key + wg.Wait() + + // Return the key + if part != nil { + return part, nil } + + // Return the errors return nil, decryptKeyErrors(keyErrs) } // decryptKey tries to decrypt the contents of the provided MasterKey with any // of the key services, returning as soon as one key service succeeds. -func decryptKey(key keys.MasterKey, svcs []keyservice.KeyServiceClient) ([]byte, error) { +func decryptKey(ctx context.Context, key keys.MasterKey, svcs []keyservice.KeyServiceClient) ([]byte, error) { svcKey := keyservice.KeyFromMasterKey(key) var part []byte + var wgCount int decryptErr := decryptKeyError{ keyName: key.ToString(), } + + // Use channels to extract the key in parallel from the services + var wg sync.WaitGroup + localCtx, cancel := context.WithCancel(ctx) + keyChan := make(chan []byte) + errChan := make(chan error) + for _, svc := range svcs { - // All keys in a key group encrypt the same part, so as soon - // as we decrypt it successfully with one key, we need to - // proceed with the next group - var err error - if part == nil { - var rsp *keyservice.DecryptResponse + // Run decryption inside of goroutine + wg.Add(1) + wgCount++ + go func(ctx context.Context, wg *sync.WaitGroup, svc keyservice.KeyServiceClient, keyChan chan []byte, errChan chan error) { + var ( + err error + rsp *keyservice.DecryptResponse + ) + + // Decrypt the key rsp, err = svc.Decrypt( - context.Background(), + ctx, &keyservice.DecryptRequest{ Ciphertext: key.EncryptedDataKey(), Key: &svcKey, - }) - if err == nil { - part = rsp.Plaintext + }, + ) + + select { + // Key was already decrypted. Time to bail + case <-ctx.Done(): + return + // Forward the response + default: + if err != nil { + errChan <- err + } else { + keyChan <- rsp.Plaintext + } } - } - decryptErr.errs = append(decryptErr.errs, err) + }(localCtx, &wg, svc, keyChan, errChan) } + + // Setup goroutine to watch for the key + go func(ctx context.Context, wg *sync.WaitGroup, keyChan chan []byte, cancel context.CancelFunc) { + for { + select { + // Bail out if the context is canceled + case <-ctx.Done(): + for i := 0; i < int(wgCount); i++ { + wg.Done() + } + return + + // Receive the key + case key := <-keyChan: + if part == nil { + part = key + } + cancel() + + // Receive the errors + case err := <-errChan: + decryptErr.errs = append(decryptErr.errs, err) + } + wg.Done() + wgCount-- + } + }(ctx, &wg, keyChan, cancel) + + // Wait for the services to return + wg.Wait() + + // Return the key if part != nil { return part, nil } + + // Return the errors return nil, &decryptErr }