diff --git a/service/history/replication/dlq_handler_test.go b/service/history/replication/dlq_handler_test.go index 3244d2879c1..95e867ab003 100644 --- a/service/history/replication/dlq_handler_test.go +++ b/service/history/replication/dlq_handler_test.go @@ -22,16 +22,17 @@ package replication import ( "context" - "testing" - + "errors" "github.com/golang/mock/gomock" "github.com/pborman/uuid" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "testing" "github.com/uber/cadence/client" "github.com/uber/cadence/client/admin" + "github.com/uber/cadence/common" "github.com/uber/cadence/common/mocks" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/types" @@ -110,11 +111,146 @@ func (s *dlqHandlerSuite) TearDownTest() { s.mockShard.Finish(s.T()) } +func (s *dlqHandlerSuite) TestNewDLQHandler_panic() { + s.Panics(func() { NewDLQHandler(s.mockShard, nil) }, "Failed to initialize replication DLQ handler due to nil task executors") +} + +func (s *dlqHandlerSuite) TestStart() { + tests := []struct { + name string + status int32 + }{ + { + name: "started", + status: common.DaemonStatusInitialized, + }, + { + name: "not started", + status: common.DaemonStatusStopped, + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + s.messageHandler.status = tc.status + + s.messageHandler.Start() + }) + } +} + +func (s *dlqHandlerSuite) TestStop() { + tests := []struct { + name string + status int32 + }{ + { + name: "stopped", + status: common.DaemonStatusStopped, + }, + { + name: "started", + status: common.DaemonStatusStarted, + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + s.messageHandler.status = tc.status + + s.messageHandler.Stop() + }) + } +} + +func (s *dlqHandlerSuite) TestGetMessageCount() { + size := int64(1) + tests := []struct { + name string + latestCounts map[string]int64 + forceFetch bool + err error + }{ + { + name: "success", + latestCounts: map[string]int64{s.sourceCluster: size}, + }, + { + name: "success with fetchAndEmitMessageCount call", + forceFetch: true, + }, + { + name: "error", + forceFetch: true, + err: errors.New("fetchAndEmitMessageCount error"), + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + s.messageHandler.latestCounts = tc.latestCounts + + if tc.forceFetch || tc.latestCounts == nil { + s.executionManager.On("GetReplicationDLQSize", mock.Anything, mock.Anything).Return(&persistence.GetReplicationDLQSizeResponse{Size: size}, tc.err).Times(1) + } + + counts, err := s.messageHandler.GetMessageCount(context.Background(), tc.forceFetch) + + if tc.err != nil { + s.Error(err) + s.Equal(tc.err, err) + } else if tc.latestCounts != nil { + s.NoError(err) + s.Equal(size, counts[s.sourceCluster]) + } else { + s.NoError(err) + } + }) + } +} + +func (s *dlqHandlerSuite) TestFetchAndEmitMessageCount() { + tests := []struct { + name string + err error + }{ + { + name: "success", + err: nil, + }, + { + name: "error", + err: errors.New("error"), + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + size := int64(3) + rets := &persistence.GetReplicationDLQSizeResponse{Size: size} + s.messageHandler.latestCounts = make(map[string]int64) + + s.executionManager.On("GetReplicationDLQSize", context.Background(), mock.Anything).Return(rets, tc.err).Times(1) + + err := s.messageHandler.fetchAndEmitMessageCount(context.Background()) + + if tc.err != nil { + s.Error(err) + s.Equal(tc.err, err) + } else { + s.NoError(err) + s.Equal(len(s.messageHandler.latestCounts), len(s.taskExecutors)) + s.Equal(size, s.messageHandler.latestCounts[s.sourceCluster]) + } + }) + } +} + func (s *dlqHandlerSuite) TestReadMessages_OK() { ctx := context.Background() lastMessageID := int64(1) pageSize := 1 - pageToken := []byte{} + var pageToken []byte resp := &persistence.GetReplicationTasksFromDLQResponse{ Tasks: []*persistence.ReplicationTaskInfo{ @@ -150,26 +286,187 @@ func (s *dlqHandlerSuite) TestReadMessages_OK() { s.Nil(tasks) } -func (s *dlqHandlerSuite) TestPurgeMessages_OK() { - sourceCluster := "test" - lastMessageID := int64(1) +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_OK() { + replicationTasksResponse := &persistence.GetReplicationTasksFromDLQResponse{ + Tasks: []*persistence.ReplicationTaskInfo{ + { + DomainID: "domainID", + TaskID: 123, + WorkflowID: "workflowID", + RunID: "runID", + TaskType: 5, + Version: 1, + FirstEventID: 1, + NextEventID: 2, + ScheduledID: 3, + }, + }, + NextPageToken: []byte("token"), + } - s.executionManager.On("RangeDeleteReplicationTaskFromDLQ", mock.Anything, - &persistence.RangeDeleteReplicationTaskFromDLQRequest{ - SourceClusterName: sourceCluster, - ExclusiveBeginTaskID: -1, - InclusiveEndTaskID: lastMessageID, - }).Return(&persistence.RangeDeleteReplicationTaskFromDLQResponse{TasksCompleted: persistence.UnknownNumRowsAffected}, nil).Times(1) + DLQReplicationMessagesResponse := &types.GetDLQReplicationMessagesResponse{ + ReplicationTasks: []*types.ReplicationTask{ + { + SourceTaskID: 123, + }, + }, + } + + ctx := context.Background() + lastMessageID := int64(123) + pageSize := 12 + pageToken := []byte("token") + + req := &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: defaultBeginningMessageID, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + } + + s.executionManager.On("GetReplicationTasksFromDLQ", ctx, req).Return(replicationTasksResponse, nil).Times(1) + + s.adminClient.EXPECT(). + GetDLQReplicationMessages(ctx, gomock.Any()). + Return(DLQReplicationMessagesResponse, nil).Times(1) + + replicationTasks, taskInfo, nextPageToken, err := s.messageHandler.readMessagesWithAckLevel(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) - err := s.messageHandler.PurgeMessages(context.Background(), sourceCluster, lastMessageID) s.NoError(err) + s.Equal(replicationTasks, DLQReplicationMessagesResponse.ReplicationTasks) + s.Len(taskInfo, len(replicationTasksResponse.Tasks)) + // testing content of taskInfo because it's assembled in the method using tasks from replicationTasksFromDLQ + for i, task := range taskInfo { + s.Equal(task.GetDomainID(), replicationTasksResponse.Tasks[i].GetDomainID()) + s.Equal(task.GetWorkflowID(), replicationTasksResponse.Tasks[i].GetWorkflowID()) + s.Equal(task.GetRunID(), replicationTasksResponse.Tasks[i].GetRunID()) + s.Equal(task.GetTaskID(), replicationTasksResponse.Tasks[i].GetTaskID()) + s.Equal(task.GetTaskType(), int16(replicationTasksResponse.Tasks[i].GetTaskType())) + s.Equal(task.GetVersion(), replicationTasksResponse.Tasks[i].GetVersion()) + s.Equal(task.FirstEventID, replicationTasksResponse.Tasks[i].FirstEventID) + s.Equal(task.NextEventID, replicationTasksResponse.Tasks[i].NextEventID) + s.Equal(task.ScheduledID, replicationTasksResponse.Tasks[i].ScheduledID) + } + s.Equal(nextPageToken, replicationTasksResponse.NextPageToken) +} + +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_GetReplicationTasksFromDLQFailed() { + errorMessage := "GetReplicationTasksFromDLQFailed" + ctx := context.Background() + lastMessageID := int64(123) + pageSize := 12 + pageToken := []byte("token") + + req := &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: defaultBeginningMessageID, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + } + + s.executionManager.On("GetReplicationTasksFromDLQ", ctx, req).Return(nil, errors.New(errorMessage)).Times(1) + + _, _, _, err := s.messageHandler.readMessagesWithAckLevel(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) + + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_InvalidCluster() { + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, mock.Anything).Return(nil, nil).Times(1) + + s.mockShard.Resource.ClientBean = client.NewMockBean(s.controller) + s.mockShard.Resource.ClientBean.EXPECT().GetRemoteAdminClient("invalidCluster").Return(nil).Times(1) + + _, _, _, err := s.messageHandler.readMessagesWithAckLevel(context.Background(), "invalidCluster", 123, 12, []byte("token")) + + s.Error(err) + s.Equal(errInvalidCluster, err) +} + +func (s *dlqHandlerSuite) TestReadMessagesWithAckLevel_GetDLQReplicationMessagesFailed() { + errorMessage := "GetDLQReplicationMessagesFailed" + ctx := context.Background() + lastMessageID := int64(123) + pageSize := 12 + pageToken := []byte("token") + + req := &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: defaultBeginningMessageID, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + } + + replicationTasksResponse := &persistence.GetReplicationTasksFromDLQResponse{ + Tasks: []*persistence.ReplicationTaskInfo{ + { + DomainID: "domainID", + }, + }, + } + + s.executionManager.On("GetReplicationTasksFromDLQ", ctx, req).Return(replicationTasksResponse, nil).Times(1) + + s.adminClient.EXPECT(). + GetDLQReplicationMessages(ctx, gomock.Any()). + Return(nil, errors.New(errorMessage)).Times(1) + + _, _, _, err := s.messageHandler.readMessagesWithAckLevel(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) + + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestPurgeMessages() { + tests := []struct { + name string + err error + }{ + { + name: "success", + }, + { + name: "error", + err: errors.New("error"), + }, + } + + for _, tc := range tests { + s.T().Run(tc.name, func(t *testing.T) { + lastMessageID := int64(1) + s.executionManager.On("RangeDeleteReplicationTaskFromDLQ", mock.Anything, + &persistence.RangeDeleteReplicationTaskFromDLQRequest{ + SourceClusterName: s.sourceCluster, + ExclusiveBeginTaskID: -1, + InclusiveEndTaskID: lastMessageID, + }).Return(&persistence.RangeDeleteReplicationTaskFromDLQResponse{TasksCompleted: persistence.UnknownNumRowsAffected}, tc.err).Times(1) + + err := s.messageHandler.PurgeMessages(context.Background(), s.sourceCluster, lastMessageID) + + if tc.err != nil { + s.Error(err) + } else { + s.NoError(err) + } + }) + } } func (s *dlqHandlerSuite) TestMergeMessages_OK() { ctx := context.Background() lastMessageID := int64(2) pageSize := 1 - pageToken := []byte{} + var pageToken []byte resp := &persistence.GetReplicationTasksFromDLQResponse{ Tasks: []*persistence.ReplicationTaskInfo{ @@ -222,6 +519,57 @@ func (s *dlqHandlerSuite) TestMergeMessages_OK() { s.Equal(1, len(s.taskExecutor.executedTasks)) } +func (s *dlqHandlerSuite) TestMergeMessages_InvalidCluster() { + _, err := s.messageHandler.MergeMessages(context.Background(), "invalid", 1, 1, nil) + s.Error(err) + s.Equal(errInvalidCluster, err) +} + +func (s *dlqHandlerSuite) TestMergeMessages_GetReplicationTasksFromDLQFailed() { + errorMessage := "GetReplicationTasksFromDLQFailed" + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, mock.Anything).Return(nil, errors.New(errorMessage)).Times(1) + _, err := s.messageHandler.MergeMessages(context.Background(), s.sourceCluster, 1, 1, nil) + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestMergeMessages_RangeDeleteReplicationTaskFromDLQFailed() { + errorMessage := "RangeDeleteReplicationTaskFromDLQFailed" + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, mock.Anything).Return(&persistence.GetReplicationTasksFromDLQResponse{}, nil).Times(1) + s.executionManager.On("RangeDeleteReplicationTaskFromDLQ", mock.Anything, mock.Anything).Return(nil, errors.New(errorMessage)).Times(1) + _, err := s.messageHandler.MergeMessages(context.Background(), s.sourceCluster, 1, 1, nil) + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + +func (s *dlqHandlerSuite) TestMergeMessages_executeFailed() { + errorMessage := "executeFailed" + s.taskExecutors[s.sourceCluster] = &fakeTaskExecutor{err: errors.New(errorMessage)} + + ctx := context.Background() + lastMessageID := int64(2) + pageSize := 1 + var pageToken []byte + + s.executionManager.On("GetReplicationTasksFromDLQ", mock.Anything, &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: s.sourceCluster, + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: -1, + MaxReadLevel: lastMessageID, + BatchSize: pageSize, + NextPageToken: pageToken, + }, + }).Return(&persistence.GetReplicationTasksFromDLQResponse{Tasks: []*persistence.ReplicationTaskInfo{{TaskID: 1}}}, nil).Times(1) + + s.adminClient.EXPECT().GetDLQReplicationMessages(ctx, gomock.Any()). + Return(&types.GetDLQReplicationMessagesResponse{ReplicationTasks: []*types.ReplicationTask{{SourceTaskID: 1}}}, nil) + + _, err := s.messageHandler.MergeMessages(ctx, s.sourceCluster, lastMessageID, pageSize, pageToken) + + s.Error(err) + s.Equal(err, errors.New(errorMessage)) +} + type fakeTaskExecutor struct { scope int err error @@ -229,7 +577,7 @@ type fakeTaskExecutor struct { executedTasks []*types.ReplicationTask } -func (e *fakeTaskExecutor) execute(replicationTask *types.ReplicationTask, forceApply bool) (int, error) { +func (e *fakeTaskExecutor) execute(replicationTask *types.ReplicationTask, _ bool) (int, error) { e.executedTasks = append(e.executedTasks, replicationTask) return e.scope, e.err }