From 95e673e5064a4cd0fa7f88a771c6fa54627ce0bf Mon Sep 17 00:00:00 2001
From: James Ryans <46216691+james-ryans@users.noreply.github.com>
Date: Tue, 16 Jan 2024 22:58:38 +0700
Subject: [PATCH] [testbed] Detect and fix data race at testbed integration
 test (#30549)

**Description:**
Bug fix: the testbed module reports many data race warnings when
tested with `go test -race`, especially for traces. The data races and
their fixes are listed below:
1. Add mutex to MockBackend `startedAt` variable
Because `tc := testbed.NewTestCase()` starts a `logStats()` goroutine
which calls `MockBackend.GetStats()` that uses the `startedAt` variable,
a data race occurs when we later execute `tc.StartBackend()`. This is
because `tc.StartBackend()` writes to the `startedAt` variable.
2. Add mutex to LoadGenerator `startedAt` variable
The reason is similar to MockBackend.
3. Move MockBackend `numSpansReceived` addition after `ConsumeTraces()`
Calling `tc.numSpansReceived.Add(uint64(td.SpanCount()))` before the
traces are consumed makes `tc.LoadGenerator.DataItemsSent()` equal to
`tc.MockBackend.DataItemsReceived()` too early. The test case then
assumes MockBackend has already received all the spans and proceeds to
`tc.ValidateData()` while MockBackend is actually still consuming the
traces.
4. Get `td.SpanCount()` before `batchprocessor` consumes the spans at
`opencensusreceiver` and `zipkinreceiver`
In a pipeline with `batchprocessor`, the processor moves the source
`td.ResourceSpans()` into the batched one. Calling `td.SpanCount()`
while `batchprocessor` is consuming the traces therefore causes a data
race.

**Testing:**
Add the `-race` flag to the `go test` invocation in
`testbed/runtests.sh` (line 22).
Run `TESTS_DIR=correctnesstests/traces make e2e-test`

---------

Signed-off-by: James Ryans <james.ryans2012@gmail.com>
---
 .../internal/octrace/opencensus.go                 |  3 ++-
 receiver/zipkinreceiver/trace_receiver.go          |  3 ++-
 testbed/testbed/load_generator.go                  |  5 +++++
 testbed/testbed/mock_backend.go                    | 14 +++++++++-----
 4 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/receiver/opencensusreceiver/internal/octrace/opencensus.go b/receiver/opencensusreceiver/internal/octrace/opencensus.go
index d15cd3f26ab3..f85e5e5a684c 100644
--- a/receiver/opencensusreceiver/internal/octrace/opencensus.go
+++ b/receiver/opencensusreceiver/internal/octrace/opencensus.go
@@ -131,8 +131,9 @@ func (ocr *Receiver) processReceivedMsg(
 func (ocr *Receiver) sendToNextConsumer(longLivedRPCCtx context.Context, td ptrace.Traces) error {
 	ctx := ocr.obsrecv.StartTracesOp(longLivedRPCCtx)
 
+	numReceivedSpans := td.SpanCount()
 	err := ocr.nextConsumer.ConsumeTraces(ctx, td)
-	ocr.obsrecv.EndTracesOp(ctx, receiverDataFormat, td.SpanCount(), err)
+	ocr.obsrecv.EndTracesOp(ctx, receiverDataFormat, numReceivedSpans, err)
 
 	return err
 }
diff --git a/receiver/zipkinreceiver/trace_receiver.go b/receiver/zipkinreceiver/trace_receiver.go
index 5b3d6108f514..14e4a5c512df 100644
--- a/receiver/zipkinreceiver/trace_receiver.go
+++ b/receiver/zipkinreceiver/trace_receiver.go
@@ -234,13 +234,14 @@ func (zr *zipkinReceiver) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	numReceivedSpans := td.SpanCount()
 	consumerErr := zr.nextConsumer.ConsumeTraces(ctx, td)
 
 	receiverTagValue := zipkinV2TagValue
 	if asZipkinv1 {
 		receiverTagValue = zipkinV1TagValue
 	}
-	obsrecv.EndTracesOp(ctx, receiverTagValue, td.SpanCount(), consumerErr)
+	obsrecv.EndTracesOp(ctx, receiverTagValue, numReceivedSpans, consumerErr)
 	if consumerErr == nil {
 		// Send back the response "Accepted" as
 		// required at https://zipkin.io/zipkin-api/#/default/post_spans
diff --git a/testbed/testbed/load_generator.go b/testbed/testbed/load_generator.go
index 8fa773228c4b..53af241624d9 100644
--- a/testbed/testbed/load_generator.go
+++ b/testbed/testbed/load_generator.go
@@ -58,6 +58,7 @@ type ProviderSender struct {
 	// Number of data items (spans or metric data points) sent.
 	dataItemsSent atomic.Uint64
 	startedAt     time.Time
+	startMutex    sync.Mutex
 
 	// Number of permanent errors received
 	permanentErrors    atomic.Uint64
@@ -116,6 +117,8 @@ func (ps *ProviderSender) Start(options LoadOptions) {
 
 	// Begin generation
 	go ps.generate()
+	ps.startMutex.Lock()
+	defer ps.startMutex.Unlock()
 	ps.startedAt = time.Now()
 }
 
@@ -148,6 +151,8 @@ func (ps *ProviderSender) IsReady() bool {
 
 // GetStats returns the stats as a printable string.
 func (ps *ProviderSender) GetStats() string {
+	ps.startMutex.Lock()
+	defer ps.startMutex.Unlock()
 	sent := ps.DataItemsSent()
 	return printer.Sprintf("Sent:%10d %s (%d/sec)", sent, ps.sendType, int(float64(sent)/time.Since(ps.startedAt).Seconds()))
 }
diff --git a/testbed/testbed/mock_backend.go b/testbed/testbed/mock_backend.go
index 141d2033e193..982e1c63cdca 100644
--- a/testbed/testbed/mock_backend.go
+++ b/testbed/testbed/mock_backend.go
@@ -41,9 +41,10 @@ type MockBackend struct {
 	logFile     *os.File
 
 	// Start/stop flags
-	isStarted bool
-	stopOnce  sync.Once
-	startedAt time.Time
+	isStarted  bool
+	stopOnce   sync.Once
+	startedAt  time.Time
+	startMutex sync.Mutex
 
 	// Recording fields.
 	isRecording     bool
@@ -100,6 +101,8 @@ func (mb *MockBackend) Start() error {
 	}
 
 	mb.isStarted = true
+	mb.startMutex.Lock()
+	defer mb.startMutex.Unlock()
 	mb.startedAt = time.Now()
 	return nil
 }
@@ -130,6 +133,8 @@ func (mb *MockBackend) EnableRecording() {
 }
 
 func (mb *MockBackend) GetStats() string {
+	mb.startMutex.Lock()
+	defer mb.startMutex.Unlock()
 	received := mb.DataItemsReceived()
 	return printer.Sprintf("Received:%10d items (%d/sec)", received, int(float64(received)/time.Since(mb.startedAt).Seconds()))
 }
@@ -190,8 +195,6 @@ func (tc *MockTraceConsumer) ConsumeTraces(_ context.Context, td ptrace.Traces)
 		return err
 	}
 
-	tc.numSpansReceived.Add(uint64(td.SpanCount()))
-
 	rs := td.ResourceSpans()
 	for i := 0; i < rs.Len(); i++ {
 		ils := rs.At(i).ScopeSpans()
@@ -221,6 +224,7 @@ func (tc *MockTraceConsumer) ConsumeTraces(_ context.Context, td ptrace.Traces)
 	}
 
 	tc.backend.ConsumeTrace(td)
+	tc.numSpansReceived.Add(uint64(td.SpanCount()))
 
 	return nil
 }