From 0b86bc84b8f3df5fc84b27a717397c4db3d51788 Mon Sep 17 00:00:00 2001 From: wxing1292 Date: Wed, 29 Mar 2023 12:33:00 -0700 Subject: [PATCH] Add poison pill support for replication stream (#4116) * Add poison pill support for replication stream & UT --- common/dynamicconfig/constants.go | 2 + service/history/configs/config.go | 2 + .../executable_activity_state_task.go | 42 +++++++++++ .../executable_activity_state_task_test.go | 39 +++++++++- .../replication/executable_history_task.go | 66 +++++++++++++++++ .../executable_history_task_test.go | 73 ++++++++++++++++--- .../replication/executable_noop_task.go | 4 + .../replication/executable_noop_task_test.go | 5 ++ .../history/replication/executable_task.go | 6 +- .../replication/executable_task_tracker.go | 9 ++- .../executable_task_tracker_mock.go | 14 ++++ .../executable_task_tracker_test.go | 53 +++++++++++++- .../replication/executable_unknown_task.go | 13 ++++ .../executable_unknown_task_test.go | 11 ++- .../executable_workflow_state_task.go | 40 ++++++++++ .../executable_workflow_state_task_test.go | 35 ++++++++- service/history/replication/fx.go | 6 +- 17 files changed, 393 insertions(+), 27 deletions(-) diff --git a/common/dynamicconfig/constants.go b/common/dynamicconfig/constants.go index 73f8af1bd2d4..c5dac669f572 100644 --- a/common/dynamicconfig/constants.go +++ b/common/dynamicconfig/constants.go @@ -652,6 +652,8 @@ const ( ReplicationTaskProcessorShardQPS = "history.ReplicationTaskProcessorShardQPS" // ReplicationBypassCorruptedData is the flag to bypass corrupted workflow data in source cluster ReplicationBypassCorruptedData = "history.ReplicationBypassCorruptedData" + // ReplicationProcessorSchedulerQueueSize is the replication task executor queue size + ReplicationProcessorSchedulerQueueSize = "history.ReplicationProcessorSchedulerQueueSize" // ReplicationProcessorSchedulerWorkerCount is the replication task executor worker count ReplicationProcessorSchedulerWorkerCount = "history.ReplicationProcessorSchedulerWorkerCount" diff --git a/service/history/configs/config.go b/service/history/configs/config.go index 6de3aaa3867f..fbd32f3ae078 100644 --- a/service/history/configs/config.go +++ b/service/history/configs/config.go @@ -234,6 +234,7 @@ type Config struct { ReplicationTaskProcessorHostQPS dynamicconfig.FloatPropertyFn ReplicationTaskProcessorShardQPS dynamicconfig.FloatPropertyFn ReplicationBypassCorruptedData dynamicconfig.BoolPropertyFnWithNamespaceIDFilter + ReplicationProcessorSchedulerQueueSize dynamicconfig.IntPropertyFn ReplicationProcessorSchedulerWorkerCount dynamicconfig.IntPropertyFn // The following are used by consistent query @@ -403,6 +404,7 @@ func NewConfig(dc *dynamicconfig.Collection, numberOfShards int32, isAdvancedVis ReplicationTaskProcessorHostQPS: dc.GetFloat64Property(dynamicconfig.ReplicationTaskProcessorHostQPS, 1500), ReplicationTaskProcessorShardQPS: dc.GetFloat64Property(dynamicconfig.ReplicationTaskProcessorShardQPS, 30), ReplicationBypassCorruptedData: dc.GetBoolPropertyFnWithNamespaceIDFilter(dynamicconfig.ReplicationBypassCorruptedData, false), + ReplicationProcessorSchedulerQueueSize: dc.GetIntProperty(dynamicconfig.ReplicationProcessorSchedulerQueueSize, 128), ReplicationProcessorSchedulerWorkerCount: dc.GetIntProperty(dynamicconfig.ReplicationProcessorSchedulerWorkerCount, 512), MaximumBufferedEventsBatch: dc.GetIntProperty(dynamicconfig.MaximumBufferedEventsBatch, 100), diff --git a/service/history/replication/executable_activity_state_task.go 
b/service/history/replication/executable_activity_state_task.go
index ae9d34b06f39..d0388ae984d7 100644
--- a/service/history/replication/executable_activity_state_task.go
+++ b/service/history/replication/executable_activity_state_task.go
@@ -29,11 +29,15 @@ import (
 	"go.temporal.io/api/serviceerror"
+	enumsspb "go.temporal.io/server/api/enums/v1"
 	"go.temporal.io/server/api/historyservice/v1"
+	persistencespb "go.temporal.io/server/api/persistence/v1"
 	replicationspb "go.temporal.io/server/api/replication/v1"
 	"go.temporal.io/server/common/definition"
+	"go.temporal.io/server/common/log/tag"
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/persistence"
 	serviceerrors "go.temporal.io/server/common/serviceerror"
 	ctasks "go.temporal.io/server/common/tasks"
 )
@@ -142,3 +146,41 @@ func (e *ExecutableActivityStateTask) HandleErr(err error) error {
 		return err
 	}
 }
+
+func (e *ExecutableActivityStateTask) MarkPoisonPill() error {
+	shardContext, err := e.ShardController.GetShardByNamespaceWorkflow(
+		namespace.ID(e.NamespaceID),
+		e.WorkflowID,
+	)
+	if err != nil {
+		return err
+	}
+
+	// TODO: GetShardID will break GetDLQReplicationMessages; we need to handle DLQ for cross-shard replication.
+	req := &persistence.PutReplicationTaskToDLQRequest{
+		ShardID:           shardContext.GetShardID(),
+		SourceClusterName: e.sourceClusterName,
+		TaskInfo: &persistencespb.ReplicationTaskInfo{
+			NamespaceId:      e.NamespaceID,
+			WorkflowId:       e.WorkflowID,
+			RunId:            e.RunID,
+			TaskId:           e.ExecutableTask.TaskID(),
+			TaskType:         enumsspb.TASK_TYPE_REPLICATION_SYNC_ACTIVITY,
+			ScheduledEventId: e.req.ScheduledEventId,
+			Version:          e.req.Version,
+		},
+	}
+
+	e.Logger.Error("enqueue activity state replication task to DLQ",
+		tag.ShardID(shardContext.GetShardID()),
+		tag.WorkflowNamespaceID(e.NamespaceID),
+		tag.WorkflowID(e.WorkflowID),
+		tag.WorkflowRunID(e.RunID),
+		tag.TaskID(e.ExecutableTask.TaskID()),
+	)
+
+	ctx, cancel := newTaskContext(e.NamespaceID)
+	defer cancel()
+
+	return shardContext.GetExecutionManager().PutReplicationTaskToDLQ(ctx, req)
+}
diff --git a/service/history/replication/executable_activity_state_task_test.go b/service/history/replication/executable_activity_state_task_test.go
index a3ea072b68ce..702f38d15ddb 100644
--- a/service/history/replication/executable_activity_state_task_test.go
+++ b/service/history/replication/executable_activity_state_task_test.go
@@ -38,8 +38,10 @@ import (
 	failurepb "go.temporal.io/api/failure/v1"
 	"go.temporal.io/api/serviceerror"
+	enumsspb "go.temporal.io/server/api/enums/v1"
 	"go.temporal.io/server/api/history/v1"
 	"go.temporal.io/server/api/historyservice/v1"
+	persistencespb "go.temporal.io/server/api/persistence/v1"
 	replicationspb "go.temporal.io/server/api/replication/v1"
 	workflowspb "go.temporal.io/server/api/workflow/v1"
 	"go.temporal.io/server/client"
@@ -47,6 +49,7 @@ import (
 	"go.temporal.io/server/common/log"
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/persistence"
 	"go.temporal.io/server/common/primitives/timestamp"
 	serviceerrors "go.temporal.io/server/common/serviceerror"
 	"go.temporal.io/server/common/xdc"
@@ -71,7 +74,8 @@ type (
 		replicationTask   *replicationspb.SyncActivityTaskAttributes
 		sourceClusterName string
 
-		task *ExecutableActivityStateTask
+		taskID int64
+		task   *ExecutableActivityStateTask
 	}
 )
@@ -116,7 +120,7 @@ func (s *executableActivityStateTaskSuite) SetupTest() {
 		VersionHistory: &history.VersionHistory{},
 	}
 	s.sourceClusterName = 
cluster.TestCurrentClusterName - + s.taskID = rand.Int63() s.task = NewExecutableActivityStateTask( ProcessToolBox{ ClusterMetadata: s.clusterMetadata, @@ -127,12 +131,13 @@ func (s *executableActivityStateTaskSuite) SetupTest() { MetricsHandler: s.metricsHandler, Logger: s.logger, }, - rand.Int63(), + s.taskID, time.Unix(0, rand.Int63()), s.replicationTask, s.sourceClusterName, ) s.task.ExecutableTask = s.executableTask + s.executableTask.EXPECT().TaskID().Return(s.taskID).AnyTimes() } func (s *executableActivityStateTaskSuite) TearDownTest() { @@ -264,3 +269,31 @@ func (s *executableActivityStateTaskSuite) TestHandleErr_Other() { err = serviceerror.NewUnavailable("") s.Equal(err, s.task.HandleErr(err)) } + +func (s *executableActivityStateTaskSuite) TestMarkPoisonPill() { + shardID := rand.Int31() + shardContext := shard.NewMockContext(s.controller) + executionManager := persistence.NewMockExecutionManager(s.controller) + s.shardController.EXPECT().GetShardByNamespaceWorkflow( + namespace.ID(s.task.NamespaceID), + s.task.WorkflowID, + ).Return(shardContext, nil).AnyTimes() + shardContext.EXPECT().GetShardID().Return(shardID).AnyTimes() + shardContext.EXPECT().GetExecutionManager().Return(executionManager).AnyTimes() + executionManager.EXPECT().PutReplicationTaskToDLQ(gomock.Any(), &persistence.PutReplicationTaskToDLQRequest{ + ShardID: shardID, + SourceClusterName: s.sourceClusterName, + TaskInfo: &persistencespb.ReplicationTaskInfo{ + NamespaceId: s.task.NamespaceID, + WorkflowId: s.task.WorkflowID, + RunId: s.task.RunID, + TaskId: s.task.ExecutableTask.TaskID(), + TaskType: enumsspb.TASK_TYPE_REPLICATION_SYNC_ACTIVITY, + ScheduledEventId: s.task.req.ScheduledEventId, + Version: s.task.req.Version, + }, + }).Return(nil) + + err := s.task.MarkPoisonPill() + s.NoError(err) +} diff --git a/service/history/replication/executable_history_task.go b/service/history/replication/executable_history_task.go index ff2705402042..b61416406ea9 100644 --- a/service/history/replication/executable_history_task.go +++ b/service/history/replication/executable_history_task.go @@ -30,11 +30,16 @@ import ( commonpb "go.temporal.io/api/common/v1" "go.temporal.io/api/serviceerror" + enumsspb "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/historyservice/v1" + persistencespb "go.temporal.io/server/api/persistence/v1" replicationspb "go.temporal.io/server/api/replication/v1" "go.temporal.io/server/common/definition" + "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/persistence" + "go.temporal.io/server/common/persistence/serialization" serviceerrors "go.temporal.io/server/common/serviceerror" ctasks "go.temporal.io/server/common/tasks" ) @@ -138,3 +143,64 @@ func (e *ExecutableHistoryTask) HandleErr(err error) error { return err } } + +func (e *ExecutableHistoryTask) MarkPoisonPill() error { + shardContext, err := e.ShardController.GetShardByNamespaceWorkflow( + namespace.ID(e.NamespaceID), + e.WorkflowID, + ) + if err != nil { + return err + } + + events, err := serialization.NewSerializer().DeserializeEvents(e.req.Events) + if err != nil { + e.Logger.Error("unable to enqueue history replication task to DLQ, ser/de error", + tag.ShardID(shardContext.GetShardID()), + tag.WorkflowNamespaceID(e.NamespaceID), + tag.WorkflowID(e.WorkflowID), + tag.WorkflowRunID(e.RunID), + tag.TaskID(e.ExecutableTask.TaskID()), + tag.Error(err), + ) + return nil + } else if len(events) == 0 { + e.Logger.Error("unable 
to enqueue history replication task to DLQ, no events",
+			tag.ShardID(shardContext.GetShardID()),
+			tag.WorkflowNamespaceID(e.NamespaceID),
+			tag.WorkflowID(e.WorkflowID),
+			tag.WorkflowRunID(e.RunID),
+			tag.TaskID(e.ExecutableTask.TaskID()),
+		)
+		return nil
+	}
+
+	// TODO: GetShardID will break GetDLQReplicationMessages; we need to handle DLQ for cross-shard replication.
+	req := &persistence.PutReplicationTaskToDLQRequest{
+		ShardID:           shardContext.GetShardID(),
+		SourceClusterName: e.sourceClusterName,
+		TaskInfo: &persistencespb.ReplicationTaskInfo{
+			NamespaceId:  e.NamespaceID,
+			WorkflowId:   e.WorkflowID,
+			RunId:        e.RunID,
+			TaskId:       e.ExecutableTask.TaskID(),
+			TaskType:     enumsspb.TASK_TYPE_REPLICATION_HISTORY,
+			FirstEventId: events[0].GetEventId(),
+			NextEventId:  events[len(events)-1].GetEventId() + 1,
+			Version:      events[0].GetVersion(),
+		},
+	}
+
+	e.Logger.Error("enqueue history replication task to DLQ",
+		tag.ShardID(shardContext.GetShardID()),
+		tag.WorkflowNamespaceID(e.NamespaceID),
+		tag.WorkflowID(e.WorkflowID),
+		tag.WorkflowRunID(e.RunID),
+		tag.TaskID(e.ExecutableTask.TaskID()),
+	)
+
+	ctx, cancel := newTaskContext(e.NamespaceID)
+	defer cancel()
+
+	return shardContext.GetExecutionManager().PutReplicationTaskToDLQ(ctx, req)
+}
diff --git a/service/history/replication/executable_history_task_test.go b/service/history/replication/executable_history_task_test.go
index 2884c6f0856f..bade477070bb 100644
--- a/service/history/replication/executable_history_task_test.go
+++ b/service/history/replication/executable_history_task_test.go
@@ -35,10 +35,14 @@ import (
 	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
 	commonpb "go.temporal.io/api/common/v1"
+	enumspb "go.temporal.io/api/enums/v1"
+	historypb "go.temporal.io/api/history/v1"
 	"go.temporal.io/api/serviceerror"
+	enumsspb "go.temporal.io/server/api/enums/v1"
 	"go.temporal.io/server/api/history/v1"
 	"go.temporal.io/server/api/historyservice/v1"
+	persistencespb "go.temporal.io/server/api/persistence/v1"
 	replicationspb "go.temporal.io/server/api/replication/v1"
 	workflowspb "go.temporal.io/server/api/workflow/v1"
 	"go.temporal.io/server/client"
@@ -46,6 +50,8 @@ import (
 	"go.temporal.io/server/common/log"
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/persistence"
+	"go.temporal.io/server/common/persistence/serialization"
 	serviceerrors "go.temporal.io/server/common/serviceerror"
 	"go.temporal.io/server/common/xdc"
 	"go.temporal.io/server/service/history/shard"
@@ -69,7 +75,8 @@ type (
 		replicationTask   *replicationspb.HistoryTaskAttributes
 		sourceClusterName string
 
-		task *ExecutableHistoryTask
+		taskID int64
+		task   *ExecutableHistoryTask
 	}
 )
@@ -96,17 +103,33 @@ func (s *executableHistoryTaskSuite) SetupTest() {
 	s.metricsHandler = metrics.NoopMetricsHandler
 	s.logger = log.NewNoopLogger()
 	s.executableTask = NewMockExecutableTask(s.controller)
+
+	firstEventID := rand.Int63()
+	nextEventID := firstEventID + 1
+	version := rand.Int63()
+	events, _ := serialization.NewSerializer().SerializeEvents([]*historypb.HistoryEvent{{
+		EventId: firstEventID,
+		Version: version,
+	}}, enumspb.ENCODING_TYPE_PROTO3)
+	newEvents, _ := serialization.NewSerializer().SerializeEvents([]*historypb.HistoryEvent{{
+		EventId: 1,
+		Version: version,
+	}}, enumspb.ENCODING_TYPE_PROTO3)
 	s.replicationTask = &replicationspb.HistoryTaskAttributes{
-		NamespaceId:         uuid.NewString(),
-		WorkflowId:          uuid.NewString(),
-		RunId:               uuid.NewString(),
-		BaseExecutionInfo: 
&workflowspb.BaseExecutionInfo{}, - VersionHistoryItems: []*history.VersionHistoryItem{}, - Events: &commonpb.DataBlob{}, - NewRunEvents: &commonpb.DataBlob{}, + NamespaceId: uuid.NewString(), + WorkflowId: uuid.NewString(), + RunId: uuid.NewString(), + BaseExecutionInfo: &workflowspb.BaseExecutionInfo{}, + VersionHistoryItems: []*history.VersionHistoryItem{{ + EventId: nextEventID - 1, + Version: version, + }}, + Events: events, + NewRunEvents: newEvents, } s.sourceClusterName = cluster.TestCurrentClusterName + s.taskID = rand.Int63() s.task = NewExecutableHistoryTask( ProcessToolBox{ ClusterMetadata: s.clusterMetadata, @@ -117,12 +140,13 @@ func (s *executableHistoryTaskSuite) SetupTest() { MetricsHandler: s.metricsHandler, Logger: s.logger, }, - rand.Int63(), + s.taskID, time.Unix(0, rand.Int63()), s.replicationTask, s.sourceClusterName, ) s.task.ExecutableTask = s.executableTask + s.executableTask.EXPECT().TaskID().Return(s.taskID).AnyTimes() } func (s *executableHistoryTaskSuite) TearDownTest() { @@ -242,3 +266,34 @@ func (s *executableHistoryTaskSuite) TestHandleErr_Other() { err = serviceerror.NewUnavailable("") s.Equal(err, s.task.HandleErr(err)) } + +func (s *executableHistoryTaskSuite) TestMarkPoisonPill() { + events, _ := serialization.NewSerializer().DeserializeEvents(s.task.req.Events) + + shardID := rand.Int31() + shardContext := shard.NewMockContext(s.controller) + executionManager := persistence.NewMockExecutionManager(s.controller) + s.shardController.EXPECT().GetShardByNamespaceWorkflow( + namespace.ID(s.task.NamespaceID), + s.task.WorkflowID, + ).Return(shardContext, nil).AnyTimes() + shardContext.EXPECT().GetShardID().Return(shardID).AnyTimes() + shardContext.EXPECT().GetExecutionManager().Return(executionManager).AnyTimes() + executionManager.EXPECT().PutReplicationTaskToDLQ(gomock.Any(), &persistence.PutReplicationTaskToDLQRequest{ + ShardID: shardID, + SourceClusterName: s.sourceClusterName, + TaskInfo: &persistencespb.ReplicationTaskInfo{ + NamespaceId: s.task.NamespaceID, + WorkflowId: s.task.WorkflowID, + RunId: s.task.RunID, + TaskId: s.task.ExecutableTask.TaskID(), + TaskType: enumsspb.TASK_TYPE_REPLICATION_HISTORY, + FirstEventId: events[0].GetEventId(), + NextEventId: events[len(events)-1].GetEventId() + 1, + Version: events[0].GetVersion(), + }, + }).Return(nil) + + err := s.task.MarkPoisonPill() + s.NoError(err) +} diff --git a/service/history/replication/executable_noop_task.go b/service/history/replication/executable_noop_task.go index 2f23ef684210..f88c2d1974e5 100644 --- a/service/history/replication/executable_noop_task.go +++ b/service/history/replication/executable_noop_task.go @@ -63,3 +63,7 @@ func (e *ExecutableNoopTask) Execute() error { func (e *ExecutableNoopTask) HandleErr(err error) error { return err } + +func (e *ExecutableNoopTask) MarkPoisonPill() error { + return nil +} diff --git a/service/history/replication/executable_noop_task_test.go b/service/history/replication/executable_noop_task_test.go index abceffb248c6..df9546df9477 100644 --- a/service/history/replication/executable_noop_task_test.go +++ b/service/history/replication/executable_noop_task_test.go @@ -116,3 +116,8 @@ func (s *executableNoopTaskSuite) TestHandleErr() { err = serviceerror.NewUnavailable("") s.Equal(err, s.task.HandleErr(err)) } + +func (s *executableNoopTaskSuite) TestMarkPoisonPill() { + err := s.task.MarkPoisonPill() + s.NoError(err) +} diff --git a/service/history/replication/executable_task.go b/service/history/replication/executable_task.go index 
57c6c7433c88..3628497d0480 100644 --- a/service/history/replication/executable_task.go +++ b/service/history/replication/executable_task.go @@ -55,6 +55,10 @@ const ( taskStateNacked = int32(ctasks.TaskStateNacked) ) +const ( + applyReplicationTimeout = 20 * time.Second +) + var ( TaskRetryPolicy = backoff.NewExponentialRetryPolicy(1 * time.Second). WithBackoffCoefficient(1.2). @@ -345,5 +349,5 @@ func newTaskContext( headers.SystemPreemptableCallerInfo, ) ctx = headers.SetCallerName(ctx, namespaceName) - return context.WithTimeout(ctx, replicationTimeout) + return context.WithTimeout(ctx, applyReplicationTimeout) } diff --git a/service/history/replication/executable_task_tracker.go b/service/history/replication/executable_task_tracker.go index 30b64f206a43..3c4665929f84 100644 --- a/service/history/replication/executable_task_tracker.go +++ b/service/history/replication/executable_task_tracker.go @@ -41,6 +41,7 @@ type ( ctasks.Task TaskID() int64 TaskCreationTime() time.Time + MarkPoisonPill() error } WatermarkInfo struct { Watermark int64 @@ -131,8 +132,12 @@ Loop: delete(t.taskIDs, task.TaskID()) t.taskQueue.Remove(element) case ctasks.TaskStateNacked: - // TODO put to DLQ, only after <- is successful, then remove from tracker - panic("implement me") + if err := task.MarkPoisonPill(); err != nil { + // unable to save poison pill, retry later + break Loop + } + delete(t.taskIDs, task.TaskID()) + t.taskQueue.Remove(element) case ctasks.TaskStateCancelled: // noop, do not remove from queue, let it block low watermark break Loop diff --git a/service/history/replication/executable_task_tracker_mock.go b/service/history/replication/executable_task_tracker_mock.go index 47fbe30e38fb..5d5333fbf6b6 100644 --- a/service/history/replication/executable_task_tracker_mock.go +++ b/service/history/replication/executable_task_tracker_mock.go @@ -126,6 +126,20 @@ func (mr *MockTrackableExecutableTaskMockRecorder) IsRetryableError(err interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsRetryableError", reflect.TypeOf((*MockTrackableExecutableTask)(nil).IsRetryableError), err) } +// MarkPoisonPill mocks base method. +func (m *MockTrackableExecutableTask) MarkPoisonPill() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPoisonPill") + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkPoisonPill indicates an expected call of MarkPoisonPill. +func (mr *MockTrackableExecutableTaskMockRecorder) MarkPoisonPill() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPoisonPill", reflect.TypeOf((*MockTrackableExecutableTask)(nil).MarkPoisonPill)) +} + // Nack mocks base method. 
func (m *MockTrackableExecutableTask) Nack(err error) { m.ctrl.T.Helper() diff --git a/service/history/replication/executable_task_tracker_test.go b/service/history/replication/executable_task_tracker_test.go index 275c77e1b1ef..6dcfc5db47d6 100644 --- a/service/history/replication/executable_task_tracker_test.go +++ b/service/history/replication/executable_task_tracker_test.go @@ -25,6 +25,7 @@ package replication import ( + "errors" "math/rand" "testing" "time" @@ -188,11 +189,35 @@ func (s *executableTaskTrackerSuite) TestLowWatermark_AckedTask() { s.Equal([]int64{}, taskIDs) } -func (s *executableTaskTrackerSuite) TestLowWatermark_CancelledTask() { +func (s *executableTaskTrackerSuite) TestLowWatermark_NackedTask_Success() { task0 := NewMockTrackableExecutableTask(s.controller) task0.EXPECT().TaskID().Return(rand.Int63()).AnyTimes() task0.EXPECT().TaskCreationTime().Return(time.Unix(0, rand.Int63())).AnyTimes() - task0.EXPECT().State().Return(ctasks.TaskStateCancelled).AnyTimes() + task0.EXPECT().State().Return(ctasks.TaskStateNacked).AnyTimes() + task0.EXPECT().MarkPoisonPill().Return(nil) + + highWatermark0 := WatermarkInfo{ + Watermark: task0.TaskID() + 1, + Timestamp: time.Unix(0, rand.Int63()), + } + s.taskTracker.TrackTasks(highWatermark0, task0) + + lowWatermark := s.taskTracker.LowWatermark() + s.Equal(highWatermark0, *lowWatermark) + + taskIDs := []int64{} + for element := s.taskTracker.taskQueue.Front(); element != nil; element = element.Next() { + taskIDs = append(taskIDs, element.Value.(TrackableExecutableTask).TaskID()) + } + s.Equal([]int64{}, taskIDs) +} + +func (s *executableTaskTrackerSuite) TestLowWatermark_NackedTask_Error() { + task0 := NewMockTrackableExecutableTask(s.controller) + task0.EXPECT().TaskID().Return(rand.Int63()).AnyTimes() + task0.EXPECT().TaskCreationTime().Return(time.Unix(0, rand.Int63())).AnyTimes() + task0.EXPECT().State().Return(ctasks.TaskStateNacked).AnyTimes() + task0.EXPECT().MarkPoisonPill().Return(errors.New("random error")) s.taskTracker.TrackTasks(WatermarkInfo{ Watermark: task0.TaskID() + 1, @@ -212,8 +237,28 @@ func (s *executableTaskTrackerSuite) TestLowWatermark_CancelledTask() { s.Equal([]int64{task0.TaskID()}, taskIDs) } -func (s *executableTaskTrackerSuite) TestLowWatermark_NackedTask() { - // TODO add support for poison pill +func (s *executableTaskTrackerSuite) TestLowWatermark_CancelledTask() { + task0 := NewMockTrackableExecutableTask(s.controller) + task0.EXPECT().TaskID().Return(rand.Int63()).AnyTimes() + task0.EXPECT().TaskCreationTime().Return(time.Unix(0, rand.Int63())).AnyTimes() + task0.EXPECT().State().Return(ctasks.TaskStateCancelled).AnyTimes() + + s.taskTracker.TrackTasks(WatermarkInfo{ + Watermark: task0.TaskID() + 1, + Timestamp: time.Unix(0, rand.Int63()), + }, task0) + + lowWatermark := s.taskTracker.LowWatermark() + s.Equal(WatermarkInfo{ + Watermark: task0.TaskID(), + Timestamp: task0.TaskCreationTime(), + }, *lowWatermark) + + taskIDs := []int64{} + for element := s.taskTracker.taskQueue.Front(); element != nil; element = element.Next() { + taskIDs = append(taskIDs, element.Value.(TrackableExecutableTask).TaskID()) + } + s.Equal([]int64{task0.TaskID()}, taskIDs) } func (s *executableTaskTrackerSuite) TestLowWatermark_PendingTask() { diff --git a/service/history/replication/executable_unknown_task.go b/service/history/replication/executable_unknown_task.go index 0bef8992682a..d41e20cfb016 100644 --- a/service/history/replication/executable_unknown_task.go +++ 
b/service/history/replication/executable_unknown_task.go
@@ -30,12 +30,15 @@ import (
 	"go.temporal.io/api/serviceerror"
+	"go.temporal.io/server/common/log/tag"
 	"go.temporal.io/server/common/metrics"
 	ctasks "go.temporal.io/server/common/tasks"
 )
 
 type (
 	ExecutableUnknownTask struct {
+		ProcessToolBox
+
 		ExecutableTask
 		task any
 	}
@@ -51,6 +54,8 @@ func NewExecutableUnknownTask(
 	task any,
 ) *ExecutableUnknownTask {
 	return &ExecutableUnknownTask{
+		ProcessToolBox: processToolBox,
+
 		ExecutableTask: NewExecutableTask(
 			processToolBox,
 			taskID,
@@ -75,3 +80,11 @@ func (e *ExecutableUnknownTask) HandleErr(err error) error {
 func (e *ExecutableUnknownTask) IsRetryableError(err error) bool {
 	return false
 }
+
+func (e *ExecutableUnknownTask) MarkPoisonPill() error {
+	e.Logger.Error("unable to enqueue unknown replication task to DLQ",
+		tag.Task(e.task),
+		tag.TaskID(e.ExecutableTask.TaskID()),
+	)
+	return nil
+}
diff --git a/service/history/replication/executable_unknown_task_test.go b/service/history/replication/executable_unknown_task_test.go
index 901bbd902128..c8cd2c1d1903 100644
--- a/service/history/replication/executable_unknown_task_test.go
+++ b/service/history/replication/executable_unknown_task_test.go
@@ -58,7 +58,8 @@ type (
 		metricsHandler metrics.Handler
 		logger         log.Logger
 
-		task *ExecutableUnknownTask
+		taskID int64
+		task   *ExecutableUnknownTask
 	}
 )
@@ -85,6 +86,7 @@ func (s *executableUnknownTaskSuite) SetupTest() {
 	s.metricsHandler = metrics.NoopMetricsHandler
 	s.logger = log.NewNoopLogger()
 
+	s.taskID = rand.Int63()
 	s.task = NewExecutableUnknownTask(
 		ProcessToolBox{
 			ClusterMetadata: s.clusterMetadata,
@@ -95,7 +97,7 @@ func (s *executableUnknownTaskSuite) SetupTest() {
 			MetricsHandler:  s.metricsHandler,
 			Logger:          s.logger,
 		},
-		rand.Int63(),
+		s.taskID,
 		time.Unix(0, rand.Int63()),
 		nil,
 	)
@@ -117,3 +119,8 @@ func (s *executableUnknownTaskSuite) TestHandleErr() {
 	err = serviceerror.NewUnavailable("")
 	s.Equal(err, s.task.HandleErr(err))
 }
+
+func (s *executableUnknownTaskSuite) TestMarkPoisonPill() {
+	err := s.task.MarkPoisonPill()
+	s.NoError(err)
+}
diff --git a/service/history/replication/executable_workflow_state_task.go b/service/history/replication/executable_workflow_state_task.go
index 824c7de03841..fddeec1d6744 100644
--- a/service/history/replication/executable_workflow_state_task.go
+++ b/service/history/replication/executable_workflow_state_task.go
@@ -29,11 +29,15 @@ import (
 	"go.temporal.io/api/serviceerror"
+	enumsspb "go.temporal.io/server/api/enums/v1"
 	"go.temporal.io/server/api/historyservice/v1"
+	persistencespb "go.temporal.io/server/api/persistence/v1"
 	replicationspb "go.temporal.io/server/api/replication/v1"
 	"go.temporal.io/server/common/definition"
+	"go.temporal.io/server/common/log/tag"
 	"go.temporal.io/server/common/metrics"
 	"go.temporal.io/server/common/namespace"
+	"go.temporal.io/server/common/persistence"
 	ctasks "go.temporal.io/server/common/tasks"
 )
@@ -119,3 +123,39 @@ func (e *ExecutableWorkflowStateTask) HandleErr(err error) error {
 		return err
 	}
 }
+
+func (e *ExecutableWorkflowStateTask) MarkPoisonPill() error {
+	shardContext, err := e.ShardController.GetShardByNamespaceWorkflow(
+		namespace.ID(e.NamespaceID),
+		e.WorkflowID,
+	)
+	if err != nil {
+		return err
+	}
+
+	// TODO: GetShardID will break GetDLQReplicationMessages; we need to handle DLQ for cross-shard replication.
+ req := &persistence.PutReplicationTaskToDLQRequest{ + ShardID: shardContext.GetShardID(), + SourceClusterName: e.sourceClusterName, + TaskInfo: &persistencespb.ReplicationTaskInfo{ + NamespaceId: e.NamespaceID, + WorkflowId: e.WorkflowID, + RunId: e.RunID, + TaskId: e.ExecutableTask.TaskID(), + TaskType: enumsspb.TASK_TYPE_REPLICATION_SYNC_WORKFLOW_STATE, + }, + } + + e.Logger.Error("enqueue workflow state replication task to DLQ", + tag.ShardID(shardContext.GetShardID()), + tag.WorkflowNamespaceID(e.NamespaceID), + tag.WorkflowID(e.WorkflowID), + tag.WorkflowRunID(e.RunID), + tag.TaskID(e.ExecutableTask.TaskID()), + ) + + ctx, cancel := newTaskContext(e.NamespaceID) + defer cancel() + + return shardContext.GetExecutionManager().PutReplicationTaskToDLQ(ctx, req) +} diff --git a/service/history/replication/executable_workflow_state_task_test.go b/service/history/replication/executable_workflow_state_task_test.go index 85a2af411c8d..2360cff790a3 100644 --- a/service/history/replication/executable_workflow_state_task_test.go +++ b/service/history/replication/executable_workflow_state_task_test.go @@ -36,6 +36,7 @@ import ( "github.com/stretchr/testify/suite" "go.temporal.io/api/serviceerror" + enumsspb "go.temporal.io/server/api/enums/v1" "go.temporal.io/server/api/historyservice/v1" persistencepb "go.temporal.io/server/api/persistence/v1" replicationspb "go.temporal.io/server/api/replication/v1" @@ -44,6 +45,7 @@ import ( "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" + "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/xdc" "go.temporal.io/server/service/history/shard" ) @@ -66,7 +68,8 @@ type ( replicationTask *replicationspb.SyncWorkflowStateTaskAttributes sourceClusterName string - task *ExecutableWorkflowStateTask + taskID int64 + task *ExecutableWorkflowStateTask } ) @@ -106,6 +109,7 @@ func (s *executableWorkflowStateTaskSuite) SetupTest() { } s.sourceClusterName = cluster.TestCurrentClusterName + s.taskID = rand.Int63() s.task = NewExecutableWorkflowStateTask( ProcessToolBox{ ClusterMetadata: s.clusterMetadata, @@ -116,12 +120,13 @@ func (s *executableWorkflowStateTaskSuite) SetupTest() { MetricsHandler: s.metricsHandler, Logger: s.logger, }, - rand.Int63(), + s.taskID, time.Unix(0, rand.Int63()), s.replicationTask, s.sourceClusterName, ) s.task.ExecutableTask = s.executableTask + s.executableTask.EXPECT().TaskID().Return(s.taskID).AnyTimes() } func (s *executableWorkflowStateTaskSuite) TearDownTest() { @@ -178,3 +183,29 @@ func (s *executableWorkflowStateTaskSuite) TestHandleErr() { err = serviceerror.NewUnavailable("") s.Equal(err, s.task.HandleErr(err)) } + +func (s *executableWorkflowStateTaskSuite) TestMarkPoisonPill() { + shardID := rand.Int31() + shardContext := shard.NewMockContext(s.controller) + executionManager := persistence.NewMockExecutionManager(s.controller) + s.shardController.EXPECT().GetShardByNamespaceWorkflow( + namespace.ID(s.task.NamespaceID), + s.task.WorkflowID, + ).Return(shardContext, nil).AnyTimes() + shardContext.EXPECT().GetShardID().Return(shardID).AnyTimes() + shardContext.EXPECT().GetExecutionManager().Return(executionManager).AnyTimes() + executionManager.EXPECT().PutReplicationTaskToDLQ(gomock.Any(), &persistence.PutReplicationTaskToDLQRequest{ + ShardID: shardID, + SourceClusterName: s.sourceClusterName, + TaskInfo: &persistencepb.ReplicationTaskInfo{ + NamespaceId: s.task.NamespaceID, + WorkflowId: s.task.WorkflowID, + RunId: s.task.RunID, + TaskId: 
s.task.ExecutableTask.TaskID(), + TaskType: enumsspb.TASK_TYPE_REPLICATION_SYNC_WORKFLOW_STATE, + }, + }).Return(nil) + + err := s.task.MarkPoisonPill() + s.NoError(err) +} diff --git a/service/history/replication/fx.go b/service/history/replication/fx.go index 1b5dd86e3d4c..19702fce1bdc 100644 --- a/service/history/replication/fx.go +++ b/service/history/replication/fx.go @@ -81,10 +81,8 @@ func ReplicationStreamSchedulerProvider( ) ctasks.Scheduler[ctasks.Task] { return ctasks.NewFIFOScheduler[ctasks.Task]( &ctasks.FIFOSchedulerOptions{ - // TODO make it configurable - QueueSize: 1024, - // TODO need to apply events sequentially per workflow - WorkerCount: func() int { return 1 }, // config.ReplicationProcessorSchedulerWorkerCount, + QueueSize: config.ReplicationProcessorSchedulerQueueSize(), + WorkerCount: config.ReplicationProcessorSchedulerWorkerCount, }, logger, )
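
Note: the executable_task_tracker.go change above is the core of this patch. On each LowWatermark() pass the tracker drops acked tasks, drops nacked tasks only once MarkPoisonPill() has durably parked them in the DLQ, and stops at the first task it cannot clear, so the watermark reported back to the source cluster never advances past work that is neither applied nor persisted. Below is a minimal, self-contained Go sketch of that pattern, not server code; the names trackedTask, markPoisonPill, and advance are hypothetical stand-ins for the tracker, task, and DLQ types in this patch.

package main

import (
	"errors"
	"fmt"
)

type taskState int

const (
	taskStatePending taskState = iota
	taskStateAcked
	taskStateNacked
)

type trackedTask struct {
	id    int64
	state taskState
}

// markPoisonPill stands in for TrackableExecutableTask.MarkPoisonPill:
// persist a permanently failing task to a dead letter queue so the
// replication stream can safely move past it.
func markPoisonPill(dlq *[]trackedTask, t trackedTask) error {
	if dlq == nil {
		return errors.New("DLQ unavailable, retry later")
	}
	*dlq = append(*dlq, t)
	return nil
}

// advance mimics the tracker's LowWatermark loop: acked tasks are removed,
// nacked tasks are removed only after the DLQ write succeeds, and anything
// else (pending, cancelled, failed DLQ write) blocks the watermark.
func advance(queue []trackedTask, dlq *[]trackedTask) []trackedTask {
	for len(queue) > 0 {
		t := queue[0]
		switch t.state {
		case taskStateAcked:
			queue = queue[1:]
		case taskStateNacked:
			if err := markPoisonPill(dlq, t); err != nil {
				return queue // keep the task; retry on the next pass
			}
			queue = queue[1:]
		default:
			return queue // pending or cancelled: watermark stops here
		}
	}
	return queue
}

func main() {
	var dlq []trackedTask
	queue := []trackedTask{
		{id: 1, state: taskStateAcked},
		{id: 2, state: taskStateNacked},
		{id: 3, state: taskStatePending},
	}
	queue = advance(queue, &dlq)
	fmt.Printf("remaining=%d dlq=%d\n", len(queue), len(dlq)) // remaining=1 dlq=1
}

Run against the sample queue in main, the pending task blocks the watermark while the nacked task lands in the DLQ, which mirrors why TestLowWatermark_NackedTask_Error above expects the task to stay in the tracker when MarkPoisonPill returns an error.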