From 2e4c3af74c0458ae7156c016d791a8da94706073 Mon Sep 17 00:00:00 2001
From: emily <emilyye@google.com>
Date: Thu, 9 Jan 2020 02:49:32 +0000
Subject: [PATCH] Allow for retries of single requests in a batch on failure

Signed-off-by: Modular Magician <magic-modules@google.com>
---
 google/batcher.go                         | 199 ++++++++++++++--------
 google/batcher_test.go                    |  57 +++++--
 google/iam_batching.go                    |   2 +-
 google/resource_google_project_service.go |   2 +-
 google/serviceusage_batching.go           |  49 +++---
 5 files changed, 195 insertions(+), 114 deletions(-)

diff --git a/google/batcher.go b/google/batcher.go
index 9f1a0ecdcb9..26b8e6ed3e4 100644
--- a/google/batcher.go
+++ b/google/batcher.go
@@ -3,21 +3,19 @@ package google
 import (
 	"context"
 	"fmt"
+	"github.com/hashicorp/errwrap"
 	"log"
 	"sync"
 	"time"
-
-	"github.com/hashicorp/errwrap"
 )
 
 const defaultBatchSendIntervalSec = 3
 
-// RequestBatcher is a global batcher object that keeps track of
-// existing batches.
-// In general, a batcher should be created per service that requires batching
-// in order to prevent blocking batching for one service due to another,
-// and to minimize the possibility of overlap in batchKey formats
-// (see SendRequestWithTimeout)
+// RequestBatcher keeps track of batched requests globally.
+// It should be created at a provider level. In general, one
+// should be created per service that requires batching to:
+//   - prevent blocking batching for one service due to another,
+//   - minimize the possibility of overlap in batchKey formats (see SendRequestWithTimeout)
 type RequestBatcher struct {
 	sync.Mutex
 
@@ -27,39 +25,41 @@ type RequestBatcher struct {
 	debugId   string
 }
 
-// BatchRequest represents a single request to a global batcher.
-type BatchRequest struct {
-	// ResourceName represents the underlying resource for which
-	// a request is made. Its format is determined by what SendF expects, but
-	// typically should be the name of the parent GCP resource being changed.
-	ResourceName string
-
-	// Body is this request's data to be passed to SendF, and may be combined
-	// with other bodies using CombineF.
-	Body interface{}
-
-	// CombineF function determines how to combine bodies from two batches.
-	CombineF batcherCombineFunc
-
-	// SendF function determines how to actually send a batched request to a
-	// third party service. The arguments given to this function are
-	// (ResourceName, Body) where Body may have been combined with other request
-	// Bodies.
-	SendF batcherSendFunc
-
-	// ID for debugging request. This should be specific to a single request
-	// (i.e. per Terraform resource)
-	DebugId string
-}
-
 // These types are meant to be the public interface to batchers. They define
-// logic to manage batch data type and behavior, and require service-specific
-// implementations per type of request per service.
-// Function type for combine existing batches and additional batch data
-type batcherCombineFunc func(body interface{}, toAdd interface{}) (interface{}, error)
+// batch data format and logic to send/combine batches, i.e. they require
+// specific implementations per type of request.
+type (
+	// BatchRequest represents a single request to a global batcher.
+	BatchRequest struct {
+		// ResourceName represents the underlying resource for which
+		// a request is made. Its format is determined by what SendF expects, but
+		// typically should be the name of the parent GCP resource being changed.
+		ResourceName string
+
+		// Body is this request's data to be passed to SendF, and may be combined
+		// with other bodies using CombineF.
+		Body interface{}
+
+		// CombineF function determines how to combine bodies from two batches.
+		CombineF BatcherCombineFunc
+
+		// SendF function determines how to actually send a batched request to a
+		// third party service. The arguments given to this function are
+		// (ResourceName, Body) where Body may have been combined with other request
+		// Bodies.
+		SendF BatcherSendFunc
+
+		// ID for debugging request. This should be specific to a single request
+		// (i.e. per Terraform resource)
+		DebugId string
+	}
 
-// Function type for sending a batch request
-type batcherSendFunc func(resourceName string, body interface{}) (interface{}, error)
+	// BatcherCombineFunc is a function type for combine existing batches and additional batch data
+	BatcherCombineFunc func(body interface{}, toAdd interface{}) (interface{}, error)
+
+	// BatcherSendFunc is a function type for sending a batch request
+	BatcherSendFunc func(resourceName string, body interface{}) (interface{}, error)
+)
 
 // batchResponse bundles an API response (data, error) tuple.
 type batchResponse struct {
@@ -67,16 +67,32 @@ type batchResponse struct {
 	err  error
 }
 
-// startedBatch refers to a processed batch whose timer to send the request has
-// already been started. The responses for the request is sent to each listener
-// channel, representing parallel callers that are waiting on requests
-// combined into this batch.
+func (br *batchResponse) IsError() bool {
+	return br.err != nil
+}
+
+// startedBatch refers to a registered batch to group batch requests coming in.
+// The timer manages the time after which a given batch is sent.
 type startedBatch struct {
 	batchKey string
+
+	// Combined Batch Request
 	*BatchRequest
 
-	listeners []chan batchResponse
-	timer     *time.Timer
+	// subscribers is a registry of the requests (batchSubscriber) combined into this batcher.
+
+	subscribers []batchSubscriber
+
+	timer *time.Timer
+}
+
+// batchSubscriber contains information required for a single request for a startedBatch.
+type batchSubscriber struct {
+	// singleRequest is the original request this subscriber represents
+	singleRequest *BatchRequest
+
+	// respCh is the channel created to communicate the result to a waiting goroutine.s
+	respCh chan batchResponse
 }
 
 // batchingConfig contains user configuration for controlling batch requests.
@@ -94,8 +110,12 @@ func NewRequestBatcher(debugId string, ctx context.Context, config *batchingConf
 		batches:        make(map[string]*startedBatch),
 	}
 
+	// Start goroutine to managing stopping the batcher if the provider-level parent context is closed.
 	go func(b *RequestBatcher) {
-		<-ctx.Done()
+		// Block until parent context is closed
+		<-b.parentCtx.Done()
+
+		log.Printf("[DEBUG] parent context canceled, cleaning up batcher batches")
 		b.stop()
 	}(batcher)
 
@@ -108,19 +128,19 @@ func (b *RequestBatcher) stop() {
 
 	log.Printf("[DEBUG] Stopping batcher %q", b.debugId)
 	for batchKey, batch := range b.batches {
-		log.Printf("[DEBUG] Cleaning up batch request %q", batchKey)
+		log.Printf("[DEBUG] Cancelling started batch for batchKey %q", batchKey)
 		batch.timer.Stop()
-		for _, l := range batch.listeners {
-			close(l)
+		for _, l := range batch.subscribers {
+			close(l.respCh)
 		}
 	}
 }
 
-// SendRequestWithTimeout is expected to be called per parallel call.
-// It manages waiting on the result of a batch request.
+// SendRequestWithTimeout is a blocking call for making a single request, run alone or as part of a batch.
+// It manages registering the single request with the batcher and waiting on the result.
 //
-// Batch requests are grouped by the given batchKey. batchKey
-// should be unique to the API request being sent, most likely similar to
+// Params:
+// batchKey: A string to group batchable requests. It should be unique to the API request being sent, similar to
 // the HTTP request URL with GCP resource ID included in the URL (the caller
 // may choose to use a key with method if needed to diff GET/read and
 // POST/create)
@@ -179,40 +199,75 @@ func (b *RequestBatcher) registerBatchRequest(batchKey string, newRequest *Batch
 		return batch.addRequest(newRequest)
 	}
 
+	// Batch doesn't exist for given batch key - create a new batch.
+
 	log.Printf("[DEBUG] Creating new batch %q from request %q", newRequest.DebugId, batchKey)
+
 	// The calling goroutine will need a channel to wait on for a response.
 	respCh := make(chan batchResponse, 1)
+	sub := batchSubscriber{
+		singleRequest: newRequest,
+		respCh:        respCh,
+	}
 
-	// Create a new batch.
+	// Create a new batch with copy of the given batch request.
 	b.batches[batchKey] = &startedBatch{
-		BatchRequest: newRequest,
-		batchKey:     batchKey,
-		listeners:    []chan batchResponse{respCh},
+		BatchRequest: &BatchRequest{
+			ResourceName: newRequest.ResourceName,
+			Body:         newRequest.Body,
+			CombineF:     newRequest.CombineF,
+			SendF:        newRequest.SendF,
+			DebugId:      fmt.Sprintf("Combined batch for started batch %q", batchKey),
+		},
+		batchKey:    batchKey,
+		subscribers: []batchSubscriber{sub},
 	}
 
 	// Start a timer to send the request
 	b.batches[batchKey].timer = time.AfterFunc(b.sendAfter, func() {
 		batch := b.popBatch(batchKey)
-
-		var resp batchResponse
 		if batch == nil {
-			log.Printf("[DEBUG] Batch not found in saved batches, running single request batch %q", batchKey)
-			resp = newRequest.send()
+			log.Printf("[ERROR] batch should have been added to saved batches - just run as single request %q", newRequest.DebugId)
+			respCh <- newRequest.send()
+			close(respCh)
 		} else {
-			log.Printf("[DEBUG] Sending batch %q combining %d requests)", batchKey, len(batch.listeners))
-			resp = batch.send()
-		}
-
-		// Send message to all goroutines waiting on result.
-		for _, ch := range batch.listeners {
-			ch <- resp
-			close(ch)
+			b.sendBatchWithSingleRetry(batchKey, batch)
 		}
 	})
 
 	return respCh, nil
 }
 
+func (b *RequestBatcher) sendBatchWithSingleRetry(batchKey string, batch *startedBatch) {
+	log.Printf("[DEBUG] Sending batch %q combining %d requests)", batchKey, len(batch.subscribers))
+	resp := batch.send()
+
+	// If the batch failed and combines more than one request, retry each single request.
+	if resp.IsError() && len(batch.subscribers) > 1 {
+		log.Printf("[DEBUG] Batch failed with error: %v", resp.err)
+		log.Printf("[DEBUG] Sending each request in batch separately")
+		for _, sub := range batch.subscribers {
+			log.Printf("[DEBUG] Retrying single request %q", sub.singleRequest.DebugId)
+			singleResp := sub.singleRequest.send()
+			log.Printf("[DEBUG] Retried single request %q returned response: %v", sub.singleRequest.DebugId, singleResp)
+
+			if singleResp.IsError() {
+				singleResp.err = errwrap.Wrapf(
+					"batch request and retry as single request failed - final error: {{err}}",
+					singleResp.err)
+			}
+			sub.respCh <- singleResp
+			close(sub.respCh)
+		}
+	} else {
+		// Send result to all subscribers
+		for _, sub := range batch.subscribers {
+			sub.respCh <- resp
+			close(sub.respCh)
+		}
+	}
+}
+
 // popBatch safely gets and removes a batch with given batchkey from the
 // RequestBatcher's started batches.
 func (b *RequestBatcher) popBatch(batchKey string) *startedBatch {
@@ -243,7 +298,11 @@ func (batch *startedBatch) addRequest(newRequest *BatchRequest) (<-chan batchRes
 	log.Printf("[DEBUG] Added batch request %q to batch. New batch body: %v", newRequest.DebugId, batch.Body)
 
 	respCh := make(chan batchResponse, 1)
-	batch.listeners = append(batch.listeners, respCh)
+	sub := batchSubscriber{
+		singleRequest: newRequest,
+		respCh:        respCh,
+	}
+	batch.subscribers = append(batch.subscribers, sub)
 	return respCh, nil
 }
 
diff --git a/google/batcher_test.go b/google/batcher_test.go
index 1078a63b48f..34df7165b8c 100644
--- a/google/batcher_test.go
+++ b/google/batcher_test.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"log"
 	"strings"
 	"sync"
 	"testing"
@@ -139,41 +140,63 @@ func TestRequestBatcher_errInSend(t *testing.T) {
 			enableBatching: true,
 		})
 
-	testResource := "resource for send error"
-	sendErrTmpl := "this is an expected error in send batch for resource %q"
+	// combineF keeps track of the batched indexes
+	testCombine := func(body interface{}, toAdd interface{}) (interface{}, error) {
+		return append(body.([]int), toAdd.([]int)...), nil
+	}
 
-	// combineF is no-op
-	testCombine := func(_ interface{}, _ interface{}) (interface{}, error) {
+	failIdx := 0
+	testResource := "RESOURCE-SEND-ERROR"
+	expectedErrMsg := fmt.Sprintf("Error - batch %q contains idx %d", testResource, failIdx)
+
+	testSendBatch := func(resourceName string, body interface{}) (interface{}, error) {
+		log.Printf("[DEBUG] sendBatch body: %+v", body)
+		for _, v := range body.([]int) {
+			if v == failIdx {
+				return nil, fmt.Errorf(expectedErrMsg)
+			}
+		}
 		return nil, nil
 	}
 
-	testSendBatch := func(resourceName string, cnt interface{}) (interface{}, error) {
-		return cnt, fmt.Errorf(sendErrTmpl, resourceName)
-	}
+	numRequests := 3
 
 	wg := sync.WaitGroup{}
-	wg.Add(2)
+	wg.Add(numRequests)
 
-	for i := 0; i < 2; i++ {
+	for i := 0; i < numRequests; i++ {
 		go func(idx int) {
 			defer wg.Done()
 
 			req := &BatchRequest{
 				DebugId:      fmt.Sprintf("sendError %d", idx),
 				ResourceName: testResource,
-				Body:         nil,
+				Body:         []int{idx},
 				CombineF:     testCombine,
 				SendF:        testSendBatch,
 			}
 
 			_, err := testBatcher.SendRequestWithTimeout("batchSendError", req, time.Duration(10)*time.Second)
-			if err == nil {
-				t.Errorf("expected error, got none")
-				return
-			}
-			expectedErr := fmt.Sprintf(sendErrTmpl, testResource)
-			if !strings.Contains(err.Error(), fmt.Sprintf(sendErrTmpl, testResource)) {
-				t.Errorf("expected error %q, got error: %v", expectedErr, err)
+			// Requests without index 0 should have succeeded
+			if idx == failIdx {
+				// We expect an error
+				if err == nil {
+					t.Errorf("expected error for request %d, got none", idx)
+				}
+				// Check error message
+				expectedErrPrefix := "batch request and retry as single request failed - final error: "
+				if !strings.Contains(err.Error(), expectedErrPrefix) {
+					t.Errorf("expected error %q to contain %q", err, expectedErrPrefix)
+				}
+				if !strings.Contains(err.Error(), expectedErrMsg) {
+					t.Errorf("expected error %q to contain %q", err, expectedErrMsg)
+				}
+			} else {
+
+				// We shouldn't get error for non-failure index
+				if err != nil {
+					t.Errorf("expected request %d to succeed, got error: %v", i, err)
+				}
 			}
 		}(i)
 	}
diff --git a/google/iam_batching.go b/google/iam_batching.go
index 31cfb3b7fe0..e5ecbfe5525 100644
--- a/google/iam_batching.go
+++ b/google/iam_batching.go
@@ -42,7 +42,7 @@ func combineBatchIamPolicyModifiers(currV interface{}, toAddV interface{}) (inte
 	return append(currModifiers, newModifiers...), nil
 }
 
-func sendBatchModifyIamPolicy(updater ResourceIamUpdater) batcherSendFunc {
+func sendBatchModifyIamPolicy(updater ResourceIamUpdater) BatcherSendFunc {
 	return func(resourceName string, body interface{}) (interface{}, error) {
 		modifiers, ok := body.([]iamPolicyModifyFunc)
 		if !ok {
diff --git a/google/resource_google_project_service.go b/google/resource_google_project_service.go
index fe1a79ccf6b..34c94c9ae60 100644
--- a/google/resource_google_project_service.go
+++ b/google/resource_google_project_service.go
@@ -120,7 +120,7 @@ func resourceGoogleProjectServiceCreate(d *schema.ResourceData, meta interface{}
 	}
 
 	srv := d.Get("service").(string)
-	err = BatchRequestEnableServices(map[string]struct{}{srv: {}}, project, d, config)
+	err = BatchRequestEnableService(srv, project, d, config)
 	if err != nil {
 		return err
 	}
diff --git a/google/serviceusage_batching.go b/google/serviceusage_batching.go
index 8ec706db0a5..9f9e66c77c7 100644
--- a/google/serviceusage_batching.go
+++ b/google/serviceusage_batching.go
@@ -16,35 +16,18 @@ const (
 // BatchRequestEnableServices can be used to batch requests to enable services
 // across resource nodes, i.e. to batch creation of several
 // google_project_service(s) resources.
-func BatchRequestEnableServices(services map[string]struct{}, project string, d *schema.ResourceData, config *Config) error {
-	// renamed service create calls are relatively likely to fail, so break out
-	// of the batched call to avoid failing that as well
-	for k := range services {
-		if v, ok := renamedServicesByOldAndNewServiceNames[k]; ok {
-			log.Printf("[DEBUG] found renamed service %s (with alternate name %s)", k, v)
-			delete(services, k)
-			// also remove the other name, so we don't enable it 2x in a row
-			delete(services, v)
-
-			// use a short timeout- failures are likely
-			log.Printf("[DEBUG] attempting user-specified name %s", k)
-			err := enableServiceUsageProjectServices([]string{k}, project, config, 1*time.Minute)
-			if err != nil {
-				log.Printf("[DEBUG] saw error %s. attempting alternate name %v", err, v)
-				err2 := enableServiceUsageProjectServices([]string{v}, project, config, 1*time.Minute)
-				if err2 != nil {
-					return fmt.Errorf("Saw 2 subsequent errors attempting to enable a renamed service: %s / %s", err, err2)
-				}
-			}
-		}
+func BatchRequestEnableService(service string, project string, d *schema.ResourceData, config *Config) error {
+	// Renamed service create calls are relatively likely to fail, so don't try to batch the call.
+	if altName, ok := renamedServicesByOldAndNewServiceNames[service]; ok {
+		return tryEnableRenamedService(service, altName, project, d, config)
 	}
 
 	req := &BatchRequest{
 		ResourceName: project,
-		Body:         stringSliceFromGolangSet(services),
+		Body:         []string{service},
 		CombineF:     combineServiceUsageServicesBatches,
 		SendF:        sendBatchFuncEnableServices(config, d.Timeout(schema.TimeoutCreate)),
-		DebugId:      fmt.Sprintf("Enable Project Services %s: %+v", project, services),
+		DebugId:      fmt.Sprintf("Enable Project Service %q for project %q", service, project),
 	}
 
 	_, err := config.requestBatcherServiceUsage.SendRequestWithTimeout(
@@ -54,6 +37,22 @@ func BatchRequestEnableServices(services map[string]struct{}, project string, d
 	return err
 }
 
+func tryEnableRenamedService(service, altName string, project string, d *schema.ResourceData, config *Config) error {
+	log.Printf("[DEBUG] found renamed service %s (with alternate name %s)", service, altName)
+	// use a short timeout- failures are likely
+
+	log.Printf("[DEBUG] attempting enabling service with user-specified name %s", service)
+	err := enableServiceUsageProjectServices([]string{altName}, project, config, 1*time.Minute)
+	if err != nil {
+		log.Printf("[DEBUG] saw error %s. attempting alternate name %v", err, altName)
+		err2 := enableServiceUsageProjectServices([]string{altName}, project, config, 1*time.Minute)
+		if err2 != nil {
+			return fmt.Errorf("Saw 2 subsequent errors attempting to enable a renamed service: %s / %s", err, err2)
+		}
+	}
+	return nil
+}
+
 func BatchRequestReadServices(project string, d *schema.ResourceData, config *Config) (interface{}, error) {
 	req := &BatchRequest{
 		ResourceName: project,
@@ -83,7 +82,7 @@ func combineServiceUsageServicesBatches(srvsRaw interface{}, toAddRaw interface{
 	return append(srvs, toAdd...), nil
 }
 
-func sendBatchFuncEnableServices(config *Config, timeout time.Duration) batcherSendFunc {
+func sendBatchFuncEnableServices(config *Config, timeout time.Duration) BatcherSendFunc {
 	return func(project string, toEnableRaw interface{}) (interface{}, error) {
 		toEnable, ok := toEnableRaw.([]string)
 		if !ok {
@@ -93,7 +92,7 @@ func sendBatchFuncEnableServices(config *Config, timeout time.Duration) batcherS
 	}
 }
 
-func sendListServices(config *Config, timeout time.Duration) batcherSendFunc {
+func sendListServices(config *Config, timeout time.Duration) BatcherSendFunc {
 	return func(project string, _ interface{}) (interface{}, error) {
 		return listCurrentlyEnabledServices(project, config, timeout)
 	}