diff --git a/service/history/shard/context_impl.go b/service/history/shard/context_impl.go
index 4f8ce397358..f2fe97d671a 100644
--- a/service/history/shard/context_impl.go
+++ b/service/history/shard/context_impl.go
@@ -129,8 +129,9 @@ type (
 		scheduledTaskMaxReadLevelMap map[string]time.Time // cluster -> scheduledTaskMaxReadLevel
 
 		// exist only in memory
-		remoteClusterInfos      map[string]*remoteClusterInfo
-		handoverNamespaces      map[string]*namespaceHandOverInfo // keyed on namespace name
+		remoteClusterInfos      map[string]*remoteClusterInfo
+		handoverNamespaces      map[string]*namespaceHandOverInfo // keyed on namespace name
+		acquireShardRetryPolicy backoff.RetryPolicy
 	}
 
 	remoteClusterInfo struct {
@@ -1621,7 +1622,9 @@ func (s *ContextImpl) transition(request contextRequest) error {
 		// Cancel lifecycle context as soon as we know we're shutting down
 		s.lifecycleCancel()
 		// This will cause the controller to remove this shard from the map and then call s.finishStop()
-		go s.closeCallback(s)
+		if s.closeCallback != nil {
+			go s.closeCallback(s)
+		}
 	}
 
 	setStateStopped := func() {
@@ -1889,8 +1892,10 @@ func (s *ContextImpl) acquireShard() {
 	// lifecycleCtx. The persistence operations called here use lifecycleCtx as their context,
 	// so if we were blocked in any of them, they should return immediately with a context
 	// canceled error.
-	policy := backoff.NewExponentialRetryPolicy(1 * time.Second).
-		WithExpirationInterval(5 * time.Minute)
+	policy := s.acquireShardRetryPolicy
+	if policy == nil {
+		policy = backoff.NewExponentialRetryPolicy(1 * time.Second).WithExpirationInterval(5 * time.Minute)
+	}
 
 	// Remember this value across attempts
 	ownershipChanged := false
diff --git a/service/history/shard/context_test.go b/service/history/shard/context_test.go
index a6face0db90..6b02771c8b3 100644
--- a/service/history/shard/context_test.go
+++ b/service/history/shard/context_test.go
@@ -27,6 +27,7 @@ package shard
 import (
 	"context"
 	"errors"
+	"fmt"
 	"math/rand"
 	"testing"
 	"time"
@@ -36,6 +37,7 @@ import (
 	"github.com/stretchr/testify/suite"
 
 	persistencespb "go.temporal.io/server/api/persistence/v1"
+	"go.temporal.io/server/common/backoff"
 	"go.temporal.io/server/common/clock"
 	"go.temporal.io/server/common/cluster"
 	"go.temporal.io/server/common/convert"
@@ -53,7 +55,7 @@ type (
 		*require.Assertions
 
 		controller           *gomock.Controller
-		mockShard            Context
+		mockShard            *ContextTest
 		mockClusterMetadata  *cluster.MockMetadata
 		mockShardManager     *persistence.MockShardManager
 		mockExecutionManager *persistence.MockExecutionManager
@@ -157,7 +159,7 @@ func (s *contextSuite) TestTimerMaxReadLevelInitialization() {
 	)
 
 	// clear shardInfo and load from persistence
-	shardContextImpl := s.mockShard.(*ContextTest)
+	shardContextImpl := s.mockShard
 	shardContextImpl.shardInfo = nil
 	err := shardContextImpl.loadShardMetadata(convert.BoolPtr(false))
 	s.NoError(err)
@@ -211,7 +213,7 @@ func (s *contextSuite) TestTimerMaxReadLevelUpdate_SingleProcessor() {
 
 	// update in single processor mode
 	s.mockShard.UpdateScheduledQueueExclusiveHighReadWatermark(cluster.TestCurrentClusterName, true)
-	scheduledTaskMaxReadLevelMap := s.mockShard.(*ContextTest).scheduledTaskMaxReadLevelMap
+	scheduledTaskMaxReadLevelMap := s.mockShard.scheduledTaskMaxReadLevelMap
 	s.Len(scheduledTaskMaxReadLevelMap, 2)
 	s.True(scheduledTaskMaxReadLevelMap[cluster.TestCurrentClusterName].After(now))
 	s.True(scheduledTaskMaxReadLevelMap[cluster.TestAlternativeClusterName].After(now))
@@ -365,3 +367,56 @@ func (s *contextSuite) TestDeleteWorkflowExecution_ErrorAndContinue_Success() {
 	s.NoError(err)
 	s.Equal(tasks.DeleteWorkflowExecutionStageCurrent|tasks.DeleteWorkflowExecutionStageMutableState|tasks.DeleteWorkflowExecutionStageVisibility|tasks.DeleteWorkflowExecutionStageHistory, stage)
 }
+
+func (s *contextSuite) TestAcquireShardOwnershipLostErrorIsNotRetried() {
+	s.mockShard.state = contextStateAcquiring
+	s.mockShard.acquireShardRetryPolicy = backoff.NewExponentialRetryPolicy(time.Nanosecond).
+		WithMaximumAttempts(5)
+	s.mockShardManager.EXPECT().UpdateShard(gomock.Any(), gomock.Any()).
+		Return(&persistence.ShardOwnershipLostError{}).Times(1)
+
+	s.mockShard.acquireShard()
+
+	s.Assert().Equal(contextStateStopping, s.mockShard.state)
+}
+
+func (s *contextSuite) TestAcquireShardNonOwnershipLostErrorIsRetried() {
+	s.mockShard.state = contextStateAcquiring
+	s.mockShard.acquireShardRetryPolicy = backoff.NewExponentialRetryPolicy(time.Nanosecond).
+		WithMaximumAttempts(5)
+	// TODO: make this 5 times instead of 6 when retry policy is fixed
+	s.mockShardManager.EXPECT().UpdateShard(gomock.Any(), gomock.Any()).
+		Return(fmt.Errorf("temp error")).Times(6)
+
+	s.mockShard.acquireShard()
+
+	s.Assert().Equal(contextStateStopping, s.mockShard.state)
+}
+
+func (s *contextSuite) TestAcquireShardEventuallySucceeds() {
+	s.mockShard.state = contextStateAcquiring
+	s.mockShard.acquireShardRetryPolicy = backoff.NewExponentialRetryPolicy(time.Nanosecond).
+		WithMaximumAttempts(5)
+	s.mockShardManager.EXPECT().UpdateShard(gomock.Any(), gomock.Any()).
+		Return(fmt.Errorf("temp error")).Times(3)
+	s.mockShardManager.EXPECT().UpdateShard(gomock.Any(), gomock.Any()).
+		Return(nil).Times(1)
+	s.mockHistoryEngine.EXPECT().NotifyNewTasks(gomock.Any(), gomock.Any()).MinTimes(1)
+
+	s.mockShard.acquireShard()
+
+	s.Assert().Equal(contextStateAcquired, s.mockShard.state)
+}
+
+func (s *contextSuite) TestAcquireShardNoError() {
+	s.mockShard.state = contextStateAcquiring
+	s.mockShard.acquireShardRetryPolicy = backoff.NewExponentialRetryPolicy(time.Nanosecond).
+		WithMaximumAttempts(5)
+	s.mockShardManager.EXPECT().UpdateShard(gomock.Any(), gomock.Any()).
+		Return(nil).Times(1)
+	s.mockHistoryEngine.EXPECT().NotifyNewTasks(gomock.Any(), gomock.Any()).MinTimes(1)
+
+	s.mockShard.acquireShard()
+
+	s.Assert().Equal(contextStateAcquired, s.mockShard.state)
+}
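
Note on the pattern the diff applies (not part of the patch itself): acquireShard() now reads s.acquireShardRetryPolicy and falls back to the production default (1s exponential backoff, 5 minute expiration) only when it is nil, which is what lets the new tests drive acquireShard() with a nanosecond-interval, 5-attempt policy. Below is a minimal, self-contained Go sketch of that inject-or-default idea; the acquirer/retryPolicy names are hypothetical and the retry loop is a stand-in, not Temporal's backoff package.

package main

import (
	"errors"
	"fmt"
	"time"
)

// retryPolicy is a tiny stand-in for an injectable backoff policy.
type retryPolicy struct {
	interval    time.Duration
	maxAttempts int
}

// acquirer mirrors the shape of the change: a nil policy means "use the
// production default"; a non-nil policy is whatever the test injected.
type acquirer struct {
	retryPolicy *retryPolicy
}

// acquire retries op using the injected policy, or a slow production default.
func (a *acquirer) acquire(op func() error) error {
	policy := a.retryPolicy
	if policy == nil {
		// Slow production default, analogous to 1s backoff / 5min expiration.
		policy = &retryPolicy{interval: time.Second, maxAttempts: 300}
	}
	var err error
	for attempt := 0; attempt < policy.maxAttempts; attempt++ {
		if err = op(); err == nil {
			return nil
		}
		time.Sleep(policy.interval)
	}
	return err
}

func main() {
	// A test would inject a near-zero policy so retries finish instantly.
	a := &acquirer{retryPolicy: &retryPolicy{interval: time.Nanosecond, maxAttempts: 5}}
	calls := 0
	err := a.acquire(func() error {
		calls++
		if calls < 3 {
			return errors.New("temp error")
		}
		return nil
	})
	fmt.Printf("calls=%d err=%v\n", calls, err) // calls=3 err=<nil>
}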