diff --git a/client/history/client_test.go b/client/history/client_test.go new file mode 100644 index 00000000000..31cb06151f0 --- /dev/null +++ b/client/history/client_test.go @@ -0,0 +1,866 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package history + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "go.uber.org/yarpc" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/types" +) + +func TestNewClient(t *testing.T) { + ctrl := gomock.NewController(t) + numberOfShards := 10 + rpcMaxSizeInBytes := 1024 + client := NewMockClient(ctrl) + peerResolver := NewMockPeerResolver(ctrl) + logger := log.NewNoop() + + c := NewClient(numberOfShards, rpcMaxSizeInBytes, client, peerResolver, logger) + assert.NotNil(t, c) +} + +func TestClient_withResponse(t *testing.T) { + tests := []struct { + name string + op func(Client) (any, error) + mock func(*MockPeerResolver, *MockClient) + want any + wantError bool + }{ + { + name: "StartWorkflowExecution", + op: func(c Client) (any, error) { + return c.StartWorkflowExecution(context.Background(), &types.HistoryStartWorkflowExecutionRequest{ + StartRequest: &types.StartWorkflowExecutionRequest{ + WorkflowID: "test-workflow", + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().StartWorkflowExecution(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.StartWorkflowExecutionResponse{}, nil).Times(1) + }, + want: &types.StartWorkflowExecutionResponse{}, + }, + { + name: "GetMutableState", + op: func(c Client) (any, error) { + return c.GetMutableState(context.Background(), &types.GetMutableStateRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.GetMutableStateResponse{}, nil).Times(1) + }, + want: &types.GetMutableStateResponse{}, + }, + { + name: "PollMutableState", + op: func(c Client) (any, error) { + return c.PollMutableState(context.Background(), &types.PollMutableStateRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().PollMutableState(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.PollMutableStateResponse{}, nil).Times(1) + }, + want: &types.PollMutableStateResponse{}, + }, + { + name: "ResetWorkflowExecution", + op: func(c Client) (any, error) { + return c.ResetWorkflowExecution(context.Background(), &types.HistoryResetWorkflowExecutionRequest{ + ResetRequest: &types.ResetWorkflowExecutionRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().ResetWorkflowExecution(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.ResetWorkflowExecutionResponse{}, nil).Times(1) + }, + want: &types.ResetWorkflowExecutionResponse{}, + }, + { + name: "DescribeWorkflowExecution", + op: func(c Client) (any, error) { + return c.DescribeWorkflowExecution(context.Background(), &types.HistoryDescribeWorkflowExecutionRequest{ + Request: &types.DescribeWorkflowExecutionRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().DescribeWorkflowExecution(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.DescribeWorkflowExecutionResponse{}, nil).Times(1) + }, + want: &types.DescribeWorkflowExecutionResponse{}, + }, + { + name: "RecordActivityTaskHeartbeat", + op: func(c Client) (any, error) { + return c.RecordActivityTaskHeartbeat(context.Background(), &types.HistoryRecordActivityTaskHeartbeatRequest{ + HeartbeatRequest: &types.RecordActivityTaskHeartbeatRequest{ + TaskToken: []byte(`{"workflowId": "test-workflow"}`), + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RecordActivityTaskHeartbeat(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.RecordActivityTaskHeartbeatResponse{}, nil).Times(1) + }, + want: &types.RecordActivityTaskHeartbeatResponse{}, + }, + { + name: "RecordActivityTaskStarted", + op: func(c Client) (any, error) { + return c.RecordActivityTaskStarted(context.Background(), &types.RecordActivityTaskStartedRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RecordActivityTaskStarted(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.RecordActivityTaskStartedResponse{}, nil).Times(1) + }, + want: &types.RecordActivityTaskStartedResponse{}, + }, + { + name: "RecordDecisionTaskStarted", + op: func(c Client) (any, error) { + return c.RecordDecisionTaskStarted(context.Background(), &types.RecordDecisionTaskStartedRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RecordDecisionTaskStarted(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.RecordDecisionTaskStartedResponse{}, nil).Times(1) + }, + want: &types.RecordDecisionTaskStartedResponse{}, + }, + { + name: "GetReplicationMessages", + op: func(c Client) (any, error) { + return c.GetReplicationMessages(context.Background(), &types.GetReplicationMessagesRequest{ + Tokens: []*types.ReplicationToken{ + { + ShardID: 100, + }, + { + ShardID: 101, + }, + { + ShardID: 102, + }, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(100).Return("test-peer-0", nil).Times(1) + p.EXPECT().FromShardID(101).Return("test-peer-1", nil).Times(1) + p.EXPECT().FromShardID(102).Return("test-peer-2", nil).Times(1) + c.EXPECT().GetReplicationMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-0")}). + Return(&types.GetReplicationMessagesResponse{ + MessagesByShard: map[int32]*types.ReplicationMessages{100: {}}, + }, nil).Times(1) + c.EXPECT().GetReplicationMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-1")}). + Return(&types.GetReplicationMessagesResponse{ + MessagesByShard: map[int32]*types.ReplicationMessages{101: {}}, + }, nil).Times(1) + c.EXPECT().GetReplicationMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-2")}). + Return(&types.GetReplicationMessagesResponse{ + MessagesByShard: map[int32]*types.ReplicationMessages{102: {}}, + }, nil).Times(1) + }, + want: &types.GetReplicationMessagesResponse{ + MessagesByShard: map[int32]*types.ReplicationMessages{ + 100: {}, + 101: {}, + 102: {}, + }, + }, + }, + { + name: "GetDLQReplicationMessages", + op: func(c Client) (any, error) { + return c.GetDLQReplicationMessages(context.Background(), &types.GetDLQReplicationMessagesRequest{ + TaskInfos: []*types.ReplicationTaskInfo{ + {WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().GetDLQReplicationMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.GetDLQReplicationMessagesResponse{}, nil).Times(1) + }, + want: &types.GetDLQReplicationMessagesResponse{}, + }, + { + name: "ReadDLQMessages", + op: func(c Client) (any, error) { + return c.ReadDLQMessages(context.Background(), &types.ReadDLQMessagesRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().ReadDLQMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.ReadDLQMessagesResponse{}, nil).Times(1) + }, + want: &types.ReadDLQMessagesResponse{}, + }, + { + name: "MergeDLQMessages", + op: func(c Client) (any, error) { + return c.MergeDLQMessages(context.Background(), &types.MergeDLQMessagesRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().MergeDLQMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.MergeDLQMessagesResponse{}, nil).Times(1) + }, + want: &types.MergeDLQMessagesResponse{}, + }, + { + name: "GetCrossClusterTasks", + op: func(c Client) (any, error) { + return c.GetCrossClusterTasks(context.Background(), &types.GetCrossClusterTasksRequest{ + ShardIDs: []int32{100, 101, 102}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(100).Return("test-peer-0", nil).Times(1) + p.EXPECT().FromShardID(101).Return("test-peer-1", nil).Times(1) + p.EXPECT().FromShardID(102).Return("test-peer-2", nil).Times(1) + c.EXPECT().GetCrossClusterTasks(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-0")}). + Return(&types.GetCrossClusterTasksResponse{ + TasksByShard: map[int32][]*types.CrossClusterTaskRequest{ + 100: {{}, {}}, + }, + }, nil).Times(1) + c.EXPECT().GetCrossClusterTasks(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-1")}). + Return(&types.GetCrossClusterTasksResponse{ + TasksByShard: map[int32][]*types.CrossClusterTaskRequest{ + 101: {{}}, + }, + }, nil).Times(1) + c.EXPECT().GetCrossClusterTasks(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-2")}). + Return(&types.GetCrossClusterTasksResponse{ + TasksByShard: map[int32][]*types.CrossClusterTaskRequest{ + 102: {{}, {}, {}}, + }, + }, nil).Times(1) + }, + want: &types.GetCrossClusterTasksResponse{ + TasksByShard: map[int32][]*types.CrossClusterTaskRequest{ + 100: {{}, {}}, + 101: {{}}, + 102: {{}, {}, {}}, + }, + FailedCauseByShard: map[int32]types.GetTaskFailedCause{}, + }, + }, + { + name: "RespondCrossClusterTasksCompleted", + op: func(c Client) (any, error) { + return c.RespondCrossClusterTasksCompleted(context.Background(), &types.RespondCrossClusterTasksCompletedRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().RespondCrossClusterTasksCompleted(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.RespondCrossClusterTasksCompletedResponse{}, nil).Times(1) + }, + want: &types.RespondCrossClusterTasksCompletedResponse{}, + }, + { + name: "GetFailoverInfo", + op: func(c Client) (any, error) { + return c.GetFailoverInfo(context.Background(), &types.GetFailoverInfoRequest{ + DomainID: "test-domain", + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromDomainID("test-domain").Return("test-peer", nil).Times(1) + c.EXPECT().GetFailoverInfo(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.GetFailoverInfoResponse{}, nil).Times(1) + }, + want: &types.GetFailoverInfoResponse{}, + }, + { + name: "DescribeHistoryHost by host address", + op: func(c Client) (any, error) { + return c.DescribeHistoryHost(context.Background(), &types.DescribeHistoryHostRequest{ + HostAddress: common.StringPtr("test-host"), + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromHostAddress("test-host").Return("test-peer", nil).Times(1) + c.EXPECT().DescribeHistoryHost(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.DescribeHistoryHostResponse{}, nil).Times(1) + }, + want: &types.DescribeHistoryHostResponse{}, + }, + { + name: "DescribeHistoryHost by workflow id", + op: func(c Client) (any, error) { + return c.DescribeHistoryHost(context.Background(), &types.DescribeHistoryHostRequest{ + ExecutionForHost: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + HostAddress: common.StringPtr("test-host"), + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().DescribeHistoryHost(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.DescribeHistoryHostResponse{}, nil).Times(1) + }, + want: &types.DescribeHistoryHostResponse{}, + }, + { + name: "DescribeHistoryHost by shard id", + op: func(c Client) (any, error) { + return c.DescribeHistoryHost(context.Background(), &types.DescribeHistoryHostRequest{ + ShardIDForHost: common.Int32Ptr(123), + ExecutionForHost: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + HostAddress: common.StringPtr("test-host"), + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().DescribeHistoryHost(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.DescribeHistoryHostResponse{}, nil).Times(1) + }, + want: &types.DescribeHistoryHostResponse{}, + }, + { + name: "DescribeMutableState", + op: func(c Client) (any, error) { + return c.DescribeMutableState(context.Background(), &types.DescribeMutableStateRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().DescribeMutableState(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.DescribeMutableStateResponse{}, nil).Times(1) + }, + want: &types.DescribeMutableStateResponse{}, + }, + { + name: "DescribeQueue", + op: func(c Client) (any, error) { + return c.DescribeQueue(context.Background(), &types.DescribeQueueRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().DescribeQueue(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.DescribeQueueResponse{}, nil).Times(1) + }, + want: &types.DescribeQueueResponse{}, + }, + { + name: "CountDLQMessages", + op: func(c Client) (any, error) { + return c.CountDLQMessages(context.Background(), &types.CountDLQMessagesRequest{}) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().GetAllPeers().Return([]string{"test-peer-0", "test-peer-1"}, nil).Times(1) + c.EXPECT().CountDLQMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-0")}). + Return(&types.HistoryCountDLQMessagesResponse{ + Entries: map[types.HistoryDLQCountKey]int64{ + {ShardID: 1}: 1, + }, + }, nil).Times(1) + c.EXPECT().CountDLQMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-1")}). + Return(&types.HistoryCountDLQMessagesResponse{ + Entries: map[types.HistoryDLQCountKey]int64{ + {ShardID: 2}: 2, + }, + }, nil).Times(1) + }, + want: &types.HistoryCountDLQMessagesResponse{ + Entries: map[types.HistoryDLQCountKey]int64{ + {ShardID: 1}: 1, + {ShardID: 2}: 2, + }, + }, + }, + { + name: "QueryWorkflow", + op: func(c Client) (any, error) { + return c.QueryWorkflow(context.Background(), &types.HistoryQueryWorkflowRequest{ + Request: &types.QueryWorkflowRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().QueryWorkflow(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.HistoryQueryWorkflowResponse{}, nil).Times(1) + }, + want: &types.HistoryQueryWorkflowResponse{}, + }, + { + name: "ResetStickyTaskList", + op: func(c Client) (any, error) { + return c.ResetStickyTaskList(context.Background(), &types.HistoryResetStickyTaskListRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().ResetStickyTaskList(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.HistoryResetStickyTaskListResponse{}, nil).Times(1) + }, + want: &types.HistoryResetStickyTaskListResponse{}, + }, + { + name: "RespondDecisionTaskCompleted", + op: func(c Client) (any, error) { + return c.RespondDecisionTaskCompleted(context.Background(), &types.HistoryRespondDecisionTaskCompletedRequest{ + CompleteRequest: &types.RespondDecisionTaskCompletedRequest{ + TaskToken: []byte(`{"workflowId": "test-workflow"}`), + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RespondDecisionTaskCompleted(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.HistoryRespondDecisionTaskCompletedResponse{}, nil).Times(1) + }, + want: &types.HistoryRespondDecisionTaskCompletedResponse{}, + }, + { + name: "RespondDecisionTaskCompleted", + op: func(c Client) (any, error) { + return c.RespondDecisionTaskCompleted(context.Background(), &types.HistoryRespondDecisionTaskCompletedRequest{ + CompleteRequest: &types.RespondDecisionTaskCompletedRequest{ + TaskToken: []byte(`{"workflowId": "test-workflow"}`), + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RespondDecisionTaskCompleted(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.HistoryRespondDecisionTaskCompletedResponse{}, nil).Times(1) + }, + want: &types.HistoryRespondDecisionTaskCompletedResponse{}, + }, + { + name: "SignalWithStartWorkflowExecution", + op: func(c Client) (any, error) { + return c.SignalWithStartWorkflowExecution(context.Background(), &types.HistorySignalWithStartWorkflowExecutionRequest{ + SignalWithStartRequest: &types.SignalWithStartWorkflowExecutionRequest{ + WorkflowID: "test-workflow", + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().SignalWithStartWorkflowExecution(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(&types.StartWorkflowExecutionResponse{}, nil).Times(1) + }, + want: &types.StartWorkflowExecutionResponse{}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := NewMockClient(ctrl) + mockPeerResolver := NewMockPeerResolver(ctrl) + c := NewClient(10, 1024, mockClient, mockPeerResolver, log.NewNoop()) + tt.mock(mockPeerResolver, mockClient) + res, err := tt.op(c) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, res) + } + }) + } +} + +func TestClient_withNoResponse(t *testing.T) { + tests := []struct { + name string + op func(Client) error + mock func(*MockPeerResolver, *MockClient) + wantError bool + }{ + { + name: "RefreshWorkflowTasks", + op: func(c Client) error { + return c.RefreshWorkflowTasks(context.Background(), &types.HistoryRefreshWorkflowTasksRequest{ + Request: &types.RefreshWorkflowTasksRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RefreshWorkflowTasks(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "PurgeDLQMessages", + op: func(c Client) error { + return c.PurgeDLQMessages(context.Background(), &types.PurgeDLQMessagesRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().PurgeDLQMessages(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "ReapplyEvents", + op: func(c Client) error { + return c.ReapplyEvents(context.Background(), &types.HistoryReapplyEventsRequest{ + Request: &types.ReapplyEventsRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().ReapplyEvents(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "TerminateWorkflowExecution", + op: func(c Client) error { + return c.TerminateWorkflowExecution(context.Background(), &types.HistoryTerminateWorkflowExecutionRequest{ + TerminateRequest: &types.TerminateWorkflowExecutionRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().TerminateWorkflowExecution(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "NotifyFailoverMarkers", + op: func(c Client) error { + return c.NotifyFailoverMarkers(context.Background(), &types.NotifyFailoverMarkersRequest{ + FailoverMarkerTokens: []*types.FailoverMarkerToken{ + { + FailoverMarker: &types.FailoverMarkerAttributes{ + DomainID: "test-domain-0", + }, + }, + { + FailoverMarker: &types.FailoverMarkerAttributes{ + DomainID: "test-domain-1", + }, + }, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromDomainID("test-domain-0").Return("test-peer-0", nil).Times(1) + p.EXPECT().FromDomainID("test-domain-1").Return("test-peer-1", nil).Times(1) + c.EXPECT().NotifyFailoverMarkers(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-0")}). + Return(nil).Times(1) + c.EXPECT().NotifyFailoverMarkers(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer-1")}). + Return(nil).Times(1) + }, + }, + { + name: "CloseShard", + op: func(c Client) error { + return c.CloseShard(context.Background(), &types.CloseShardRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().CloseShard(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RemoveTask", + op: func(c Client) error { + return c.RemoveTask(context.Background(), &types.RemoveTaskRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().RemoveTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RecordChildExecutionCompleted", + op: func(c Client) error { + return c.RecordChildExecutionCompleted(context.Background(), &types.RecordChildExecutionCompletedRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RecordChildExecutionCompleted(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RefreshWorkflowTasks", + op: func(c Client) error { + return c.RefreshWorkflowTasks(context.Background(), &types.HistoryRefreshWorkflowTasksRequest{ + Request: &types.RefreshWorkflowTasksRequest{ + Execution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RefreshWorkflowTasks(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RemoveSignalMutableState", + op: func(c Client) error { + return c.RemoveSignalMutableState(context.Background(), &types.RemoveSignalMutableStateRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RemoveSignalMutableState(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "ReplicateEventsV2", + op: func(c Client) error { + return c.ReplicateEventsV2(context.Background(), &types.ReplicateEventsV2Request{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().ReplicateEventsV2(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RequestCancelWorkflowExecution", + op: func(c Client) error { + return c.RequestCancelWorkflowExecution(context.Background(), &types.HistoryRequestCancelWorkflowExecutionRequest{ + CancelRequest: &types.RequestCancelWorkflowExecutionRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RequestCancelWorkflowExecution(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "ResetQueue", + op: func(c Client) error { + return c.ResetQueue(context.Background(), &types.ResetQueueRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().ResetQueue(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RespondActivityTaskCanceled", + op: func(c Client) error { + return c.RespondActivityTaskCanceled(context.Background(), &types.HistoryRespondActivityTaskCanceledRequest{ + CancelRequest: &types.RespondActivityTaskCanceledRequest{ + TaskToken: []byte(`{"workflowId": "test-workflow"}`), + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RespondActivityTaskCanceled(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RespondActivityTaskCompleted", + op: func(c Client) error { + return c.RespondActivityTaskCompleted(context.Background(), &types.HistoryRespondActivityTaskCompletedRequest{ + CompleteRequest: &types.RespondActivityTaskCompletedRequest{ + TaskToken: []byte(`{"workflowId": "test-workflow"}`), + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RespondActivityTaskCompleted(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RespondActivityTaskFailed", + op: func(c Client) error { + return c.RespondActivityTaskFailed(context.Background(), &types.HistoryRespondActivityTaskFailedRequest{ + FailedRequest: &types.RespondActivityTaskFailedRequest{ + TaskToken: []byte(`{"workflowId": "test-workflow"}`), + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RespondActivityTaskFailed(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "RespondDecisionTaskFailed", + op: func(c Client) error { + return c.RespondDecisionTaskFailed(context.Background(), &types.HistoryRespondDecisionTaskFailedRequest{ + FailedRequest: &types.RespondDecisionTaskFailedRequest{ + TaskToken: []byte(`{"workflowId": "test-workflow"}`), + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().RespondDecisionTaskFailed(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "ScheduleDecisionTask", + op: func(c Client) error { + return c.ScheduleDecisionTask(context.Background(), &types.ScheduleDecisionTaskRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().ScheduleDecisionTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "SignalWorkflowExecution", + op: func(c Client) error { + return c.SignalWorkflowExecution(context.Background(), &types.HistorySignalWorkflowExecutionRequest{ + SignalRequest: &types.SignalWorkflowExecutionRequest{ + WorkflowExecution: &types.WorkflowExecution{WorkflowID: "test-workflow"}, + }, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().SignalWorkflowExecution(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "SyncActivity", + op: func(c Client) error { + return c.SyncActivity(context.Background(), &types.SyncActivityRequest{ + WorkflowID: "test-workflow", + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromWorkflowID("test-workflow").Return("test-peer", nil).Times(1) + c.EXPECT().SyncActivity(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + { + name: "SyncShardStatus", + op: func(c Client) error { + return c.SyncShardStatus(context.Background(), &types.SyncShardStatusRequest{ + ShardID: 123, + }) + }, + mock: func(p *MockPeerResolver, c *MockClient) { + p.EXPECT().FromShardID(123).Return("test-peer", nil).Times(1) + c.EXPECT().SyncShardStatus(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("test-peer")}). + Return(nil).Times(1) + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := NewMockClient(ctrl) + mockPeerResolver := NewMockPeerResolver(ctrl) + c := NewClient(10, 1024, mockClient, mockPeerResolver, log.NewNoop()) + tt.mock(mockPeerResolver, mockClient) + err := tt.op(c) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/client/history/peerResolver.go b/client/history/peerResolver.go index 8a542e42cc7..f1bc522cb25 100644 --- a/client/history/peerResolver.go +++ b/client/history/peerResolver.go @@ -29,7 +29,17 @@ import ( // PeerResolver is used to resolve history peers. // Those are deployed instances of Cadence history services that participate in the cluster ring. // The resulting peer is simply an address of form ip:port where RPC calls can be routed to. -type PeerResolver struct { +// +//go:generate mockgen -package $GOPACKAGE -source $GOFILE -destination peerResolver_mock.go -package history github.com/uber/cadence/client/history PeerResolver +type PeerResolver interface { + FromWorkflowID(workflowID string) (string, error) + FromDomainID(domainID string) (string, error) + FromShardID(shardID int) (string, error) + FromHostAddress(hostAddress string) (string, error) + GetAllPeers() ([]string, error) +} + +type peerResolver struct { numberOfShards int resolver membership.Resolver namedPort string // grpc or tchannel, depends on yarpc configuration @@ -37,7 +47,7 @@ type PeerResolver struct { // NewPeerResolver creates a new history peer resolver. func NewPeerResolver(numberOfShards int, resolver membership.Resolver, namedPort string) PeerResolver { - return PeerResolver{ + return peerResolver{ numberOfShards: numberOfShards, resolver: resolver, namedPort: namedPort, @@ -47,7 +57,7 @@ func NewPeerResolver(numberOfShards int, resolver membership.Resolver, namedPort // FromWorkflowID resolves the history peer responsible for a given workflowID. // WorkflowID is converted to logical shardID using a consistent hash function. // FromShardID is used for further resolving. -func (pr PeerResolver) FromWorkflowID(workflowID string) (string, error) { +func (pr peerResolver) FromWorkflowID(workflowID string) (string, error) { shardID := common.WorkflowIDToHistoryShard(workflowID, pr.numberOfShards) return pr.FromShardID(shardID) } @@ -55,7 +65,7 @@ func (pr PeerResolver) FromWorkflowID(workflowID string) (string, error) { // FromDomainID resolves the history peer responsible for a given domainID. // DomainID is converted to logical shardID using a consistent hash function. // FromShardID is used for further resolving. -func (pr PeerResolver) FromDomainID(domainID string) (string, error) { +func (pr peerResolver) FromDomainID(domainID string) (string, error) { shardID := common.DomainIDToHistoryShard(domainID, pr.numberOfShards) return pr.FromShardID(shardID) } @@ -63,7 +73,7 @@ func (pr PeerResolver) FromDomainID(domainID string) (string, error) { // FromShardID resolves the history peer responsible for a given logical shardID. // It uses our membership provider to lookup which instance currently owns the given shard. // FromHostAddress is used for further resolving. -func (pr PeerResolver) FromShardID(shardID int) (string, error) { +func (pr peerResolver) FromShardID(shardID int) (string, error) { shardIDString := string(rune(shardID)) host, err := pr.resolver.Lookup(service.History, shardIDString) if err != nil { @@ -75,7 +85,7 @@ func (pr PeerResolver) FromShardID(shardID int) (string, error) { // FromHostAddress resolves the final history peer responsible for the given host address. // The address is formed by adding port for specified transport -func (pr PeerResolver) FromHostAddress(hostAddress string) (string, error) { +func (pr peerResolver) FromHostAddress(hostAddress string) (string, error) { host, err := pr.resolver.LookupByAddress(service.History, hostAddress) if err != nil { return "", common.ToServiceTransientError(err) @@ -85,7 +95,7 @@ func (pr PeerResolver) FromHostAddress(hostAddress string) (string, error) { } // GetAllPeers returns all history service peers in the cluster ring. -func (pr PeerResolver) GetAllPeers() ([]string, error) { +func (pr peerResolver) GetAllPeers() ([]string, error) { hosts, err := pr.resolver.Members(service.History) if err != nil { return nil, common.ToServiceTransientError(err) diff --git a/client/history/peerResolver_mock.go b/client/history/peerResolver_mock.go new file mode 100644 index 00000000000..69ecbe7895c --- /dev/null +++ b/client/history/peerResolver_mock.go @@ -0,0 +1,131 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Code generated by MockGen. DO NOT EDIT. +// Source: peerResolver.go + +// Package history is a generated GoMock package. +package history + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockPeerResolver is a mock of PeerResolver interface. +type MockPeerResolver struct { + ctrl *gomock.Controller + recorder *MockPeerResolverMockRecorder +} + +// MockPeerResolverMockRecorder is the mock recorder for MockPeerResolver. +type MockPeerResolverMockRecorder struct { + mock *MockPeerResolver +} + +// NewMockPeerResolver creates a new mock instance. +func NewMockPeerResolver(ctrl *gomock.Controller) *MockPeerResolver { + mock := &MockPeerResolver{ctrl: ctrl} + mock.recorder = &MockPeerResolverMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPeerResolver) EXPECT() *MockPeerResolverMockRecorder { + return m.recorder +} + +// FromDomainID mocks base method. +func (m *MockPeerResolver) FromDomainID(domainID string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FromDomainID", domainID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FromDomainID indicates an expected call of FromDomainID. +func (mr *MockPeerResolverMockRecorder) FromDomainID(domainID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FromDomainID", reflect.TypeOf((*MockPeerResolver)(nil).FromDomainID), domainID) +} + +// FromHostAddress mocks base method. +func (m *MockPeerResolver) FromHostAddress(hostAddress string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FromHostAddress", hostAddress) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FromHostAddress indicates an expected call of FromHostAddress. +func (mr *MockPeerResolverMockRecorder) FromHostAddress(hostAddress interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FromHostAddress", reflect.TypeOf((*MockPeerResolver)(nil).FromHostAddress), hostAddress) +} + +// FromShardID mocks base method. +func (m *MockPeerResolver) FromShardID(shardID int) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FromShardID", shardID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FromShardID indicates an expected call of FromShardID. +func (mr *MockPeerResolverMockRecorder) FromShardID(shardID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FromShardID", reflect.TypeOf((*MockPeerResolver)(nil).FromShardID), shardID) +} + +// FromWorkflowID mocks base method. +func (m *MockPeerResolver) FromWorkflowID(workflowID string) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FromWorkflowID", workflowID) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FromWorkflowID indicates an expected call of FromWorkflowID. +func (mr *MockPeerResolverMockRecorder) FromWorkflowID(workflowID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FromWorkflowID", reflect.TypeOf((*MockPeerResolver)(nil).FromWorkflowID), workflowID) +} + +// GetAllPeers mocks base method. +func (m *MockPeerResolver) GetAllPeers() ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllPeers") + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllPeers indicates an expected call of GetAllPeers. +func (mr *MockPeerResolverMockRecorder) GetAllPeers() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPeers", reflect.TypeOf((*MockPeerResolver)(nil).GetAllPeers)) +} diff --git a/codecov.yml b/codecov.yml index cce2a121101..ec0a45dae55 100644 --- a/codecov.yml +++ b/codecov.yml @@ -46,6 +46,7 @@ ignore: - "**/version.go" - "bench/**" - "canary/**" + - "cmd/**" - "common/persistence/persistence-tests/**" - "common/domain/errors.go" - "common/log/**" @@ -57,4 +58,5 @@ ignore: - "idls/**" - "service/history/workflow/errors.go" - "testflags/**" + - "tools/common/schema/test/**" - "tools/linter/**" diff --git a/common/domain/failover_watcher.go b/common/domain/failover_watcher.go index 25a9965e989..e93d5b3e569 100644 --- a/common/domain/failover_watcher.go +++ b/common/domain/failover_watcher.go @@ -119,7 +119,7 @@ func (p *failoverWatcherImpl) Stop() { func (p *failoverWatcherImpl) refreshDomainLoop() { - timer := time.NewTimer(backoff.JitDuration( + timer := p.timeSource.NewTimer(backoff.JitDuration( p.refreshInterval(), p.refreshJitter(), )) @@ -129,7 +129,7 @@ func (p *failoverWatcherImpl) refreshDomainLoop() { select { case <-p.shutdownChan: return - case <-timer.C: + case <-timer.Chan(): domains := p.domainCache.GetAllDomain() for _, domain := range domains { p.handleFailoverTimeout(domain) diff --git a/common/domain/failover_watcher_test.go b/common/domain/failover_watcher_test.go index 6239e8fd5bd..3aadbdb7562 100644 --- a/common/domain/failover_watcher_test.go +++ b/common/domain/failover_watcher_test.go @@ -21,6 +21,7 @@ package domain import ( + "errors" "log" "os" "testing" @@ -76,7 +77,7 @@ func (s *failoverWatcherSuite) SetupTest() { s.controller = gomock.NewController(s.T()) s.mockDomainCache = cache.NewMockDomainCache(s.controller) - s.timeSource = clock.NewRealTimeSource() + s.timeSource = clock.NewMockedTimeSource() s.mockMetadataMgr = &mocks.MetadataManager{} s.mockMetadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{ @@ -243,3 +244,77 @@ func (s *failoverWatcherSuite) TestHandleFailoverTimeout() { ) s.watcher.handleFailoverTimeout(domainEntry) } + +func (s *failoverWatcherSuite) TestStart() { + s.Assertions.Equal(common.DaemonStatusInitialized, s.watcher.status) + s.watcher.Start() + s.Assertions.Equal(common.DaemonStatusStarted, s.watcher.status) + + // Verify that calling Start again does not change the status + s.watcher.Start() + s.Assertions.Equal(common.DaemonStatusStarted, s.watcher.status) + s.watcher.Stop() +} + +func (s *failoverWatcherSuite) TestIsUpdateDomainRetryable() { + testCases := []struct { + name string + inputErr error + wantRetry bool + }{ + {"nil error", nil, true}, + {"non-nil error", errors.New("some error"), true}, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + retry := isUpdateDomainRetryable(tc.inputErr) + s.Equal(tc.wantRetry, retry) + }) + } +} + +func (s *failoverWatcherSuite) TestRefreshDomainLoop() { + + domainName := "testDomain" + domainID := uuid.New() + failoverEndTime := common.Int64Ptr(time.Now().Add(-time.Hour).UnixNano()) // 1 hour in the past + mockTimeSource, _ := s.timeSource.(clock.MockedTimeSource) + + domainInfo := &persistence.DomainInfo{ID: domainID, Name: domainName} + domainConfig := &persistence.DomainConfig{Retention: 1, EmitMetric: true} + replicationConfig := &persistence.DomainReplicationConfig{ActiveClusterName: "active", Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: "active"}}} + domainEntry := cache.NewDomainCacheEntryForTest(domainInfo, domainConfig, true, replicationConfig, 1, failoverEndTime) + + domainsMap := map[string]*cache.DomainCacheEntry{domainID: domainEntry} + s.mockDomainCache.EXPECT().GetAllDomain().Return(domainsMap).AnyTimes() + + s.mockMetadataMgr.On("GetMetadata", mock.Anything).Return(&persistence.GetMetadataResponse{NotificationVersion: 1}, nil).Maybe() + + s.mockMetadataMgr.On("GetDomain", mock.Anything, mock.AnythingOfType("*persistence.GetDomainRequest")).Return(&persistence.GetDomainResponse{ + Info: domainInfo, + Config: domainConfig, + ReplicationConfig: replicationConfig, + IsGlobalDomain: true, + ConfigVersion: 1, + FailoverVersion: 1, + FailoverNotificationVersion: 1, + FailoverEndTime: failoverEndTime, + NotificationVersion: 1, + }, nil).Once() + + s.mockMetadataMgr.On("UpdateDomain", mock.Anything, mock.Anything).Return(nil).Once() + + s.watcher.Start() + + // Delay to allow loop to start + time.Sleep(1 * time.Second) + mockTimeSource.Advance(12 * time.Second) + // Now stop the watcher, which should trigger the shutdown case in refreshDomainLoop + s.watcher.Stop() + + // Enough time for shutdown process to complete + time.Sleep(1 * time.Second) + + s.mockMetadataMgr.AssertExpectations(s.T()) +} diff --git a/common/domain/handler_test.go b/common/domain/handler_integration_test.go similarity index 100% rename from common/domain/handler_test.go rename to common/domain/handler_integration_test.go diff --git a/common/domain/replicationTaskExecutor_test.go b/common/domain/replicationTaskExecutor_test.go index 2d78a84ad37..22fb6003225 100644 --- a/common/domain/replicationTaskExecutor_test.go +++ b/common/domain/replicationTaskExecutor_test.go @@ -21,7 +21,10 @@ package domain import ( + "context" + "errors" "testing" + "time" "github.com/golang/mock/gomock" "github.com/pborman/uuid" @@ -30,6 +33,7 @@ import ( "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/log/testlogger" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/types" ) @@ -248,3 +252,418 @@ func TestDomainReplicationTaskExecutor_Execute(t *testing.T) { }) } } + +func domainCreationTask() *types.DomainTaskAttributes { + return &types.DomainTaskAttributes{ + DomainOperation: types.DomainOperationCreate.Ptr(), + ID: "testDomainID", + Info: &types.DomainInfo{ + Name: "testDomain", + Status: types.DomainStatusRegistered.Ptr(), + Description: "This is a test domain", + OwnerEmail: "owner@test.com", + Data: map[string]string{"key1": "value1"}, // Arbitrary domain metadata + }, + Config: &types.DomainConfiguration{ + WorkflowExecutionRetentionPeriodInDays: 10, + EmitMetric: true, + HistoryArchivalStatus: types.ArchivalStatusEnabled.Ptr(), + HistoryArchivalURI: "test://history/archival", + VisibilityArchivalStatus: types.ArchivalStatusEnabled.Ptr(), + VisibilityArchivalURI: "test://visibility/archival", + }, + ReplicationConfig: &types.DomainReplicationConfiguration{ + ActiveClusterName: "activeClusterName", + Clusters: []*types.ClusterReplicationConfiguration{ + { + ClusterName: "activeClusterName", + }, + { + ClusterName: "standbyClusterName", + }, + }, + }, + ConfigVersion: 1, + FailoverVersion: 1, + PreviousFailoverVersion: 0, + } +} + +func TestHandleDomainCreationReplicationTask(t *testing.T) { + tests := []struct { + name string + task *types.DomainTaskAttributes + setup func(mockDomainManager *persistence.MockDomainManager) + wantError bool + }{ + { + name: "Successful Domain Creation", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(&persistence.CreateDomainResponse{ID: "testDomainID"}, nil) + }, + wantError: false, + }, + { + name: "Generic Error During Domain Creation", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, types.InternalServiceError{Message: "an internal error"}). + Times(1) + + // Since CreateDomain failed, handleDomainCreationReplicationTask check for domain existence by name and ID + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(nil, &types.EntityNotExistsError{}). // Simulate that no domain exists with the given name/ID + AnyTimes() + }, + wantError: true, + }, + { + name: "Handle Name/UUID Collision - EntityNotExistsError", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, ErrNameUUIDCollision).Times(1) + + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()).Return(nil, &types.EntityNotExistsError{}).AnyTimes() + }, + wantError: true, + }, + { + name: "Immediate Error Return from CreateDomain", + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, types.InternalServiceError{Message: "internal error"}). + Times(1) + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(nil, ErrInvalidDomainStatus). + AnyTimes() + }, + task: domainCreationTask(), + wantError: true, + }, + { + name: "Domain Creation with Nil Status", + task: &types.DomainTaskAttributes{ + DomainOperation: types.DomainOperationCreate.Ptr(), + ID: "testDomainID", + Info: &types.DomainInfo{ + Name: "testDomain", + // Status is intentionally left as nil to trigger the error + }, + }, + setup: func(mockDomainManager *persistence.MockDomainManager) { + // No need to set up a mock for CreateDomain as the call should not reach this point + }, + wantError: true, + }, + { + name: "Domain Creation with Unrecognized Status", + task: &types.DomainTaskAttributes{ + DomainOperation: types.DomainOperationCreate.Ptr(), + ID: "testDomainID", + Info: &types.DomainInfo{ + Name: "testDomain", + Status: types.DomainStatus(999).Ptr(), // Assuming 999 is an unrecognized status + }, + }, + setup: func(mockDomainManager *persistence.MockDomainManager) { + // As before, no need for mock setup for CreateDomain + }, + wantError: true, + }, + { + name: "Unexpected Error Type from GetDomain Leads to Default Error Handling", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, ErrInvalidDomainStatus).Times(1) + + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(nil, errors.New("unexpected error")).Times(1) + }, + wantError: true, + }, + { + name: "Successful GetDomain with Name/UUID Mismatch", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, ErrNameUUIDCollision).AnyTimes() + + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(&persistence.GetDomainResponse{ + Info: &persistence.DomainInfo{ID: "testDomainID", Name: "mismatchName"}, + }, nil).AnyTimes() + }, + wantError: true, + }, + { + name: "Handle Domain Creation with Unhandled Error", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(nil, &types.EntityNotExistsError{}). + AnyTimes() + + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, errors.New("unhandled error")). + Times(1) + }, + wantError: true, + }, + { + name: "Handle Domain Creation - Unexpected Error from GetDomain", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, errors.New("test error")).Times(1) + + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(nil, &types.EntityNotExistsError{}).Times(1) + + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(nil, errors.New("unexpected error")).Times(1) + }, + wantError: true, + }, + { + name: "Duplicate Domain Creation With Same ID and Name", + task: domainCreationTask(), + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), gomock.Any()). + Return(nil, ErrNameUUIDCollision).Times(1) + + // Setup GetDomain to return matching ID and Name, indicating no actual conflict + // This setup ensures that recordExists becomes true + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(&persistence.GetDomainResponse{ + Info: &persistence.DomainInfo{ID: "testDomainID", Name: "testDomain"}, + }, nil).Times(2) + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + + mockDomainManager := persistence.NewMockDomainManager(ctrl) + mockLogger := testlogger.New(t) + mockTimeSource := clock.NewMockedTimeSourceAt(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)) // Fixed time + + executor := &domainReplicationTaskExecutorImpl{ + domainManager: mockDomainManager, + logger: mockLogger, + timeSource: mockTimeSource, + } + + tt.setup(mockDomainManager) + err := executor.handleDomainCreationReplicationTask(context.Background(), tt.task) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestHandleDomainUpdateReplicationTask(t *testing.T) { + tests := []struct { + name string + task *types.DomainTaskAttributes + wantErr bool + setup func(mockDomainManager *persistence.MockDomainManager) + }{ + { + name: "Convert Status Error", + task: &types.DomainTaskAttributes{ + Info: &types.DomainInfo{ + Status: func() *types.DomainStatus { var ds types.DomainStatus = 999; return &ds }(), // Invalid status to trigger conversion error + }, + }, + wantErr: true, + setup: func(dm *persistence.MockDomainManager) {}, + }, + { + name: "Error Fetching Metadata", + task: domainCreationTask(), + wantErr: true, + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + GetMetadata(gomock.Any()). + Return(nil, errors.New("Error getting metadata while handling replication task")). + AnyTimes() + }, + }, + { + name: "GetDomain Returns General Error", + task: &types.DomainTaskAttributes{ + Info: &types.DomainInfo{ + Name: "someDomain", + }, + }, + wantErr: true, + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(nil, errors.New("general error")).AnyTimes() + + mockDomainManager.EXPECT(). + GetMetadata(gomock.Any()). + Return(&persistence.GetMetadataResponse{}, nil).AnyTimes() + + }, + }, + { + name: "GetDomain Returns EntityNotExistsError - Triggers Domain Creation", + task: &types.DomainTaskAttributes{ + Info: &types.DomainInfo{ + Name: "nonexistentDomain", + }, + }, + wantErr: true, + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), &persistence.GetDomainRequest{ + Name: "nonexistentDomain", + }). + Return(nil, &types.EntityNotExistsError{}).AnyTimes() + + mockDomainManager.EXPECT(). + GetMetadata(gomock.Any()). + Return(&persistence.GetMetadataResponse{NotificationVersion: 1}, nil). + AnyTimes() + + mockDomainManager.EXPECT(). + CreateDomain(gomock.Any(), &types.DomainTaskAttributes{ + Info: &types.DomainInfo{ + Name: "nonexistentDomain", + }, + }). + Return(&persistence.CreateDomainResponse{ + ID: "nonexistentDomain", + }, nil). + AnyTimes() + }, + }, + { + name: "Record Not Updated then return nil", + task: &types.DomainTaskAttributes{ + Info: &types.DomainInfo{ + Name: "testDomain", + Status: types.DomainStatusRegistered.Ptr(), + }, + Config: &types.DomainConfiguration{ + BadBinaries: &types.BadBinaries{ + Binaries: map[string]*types.BadBinaryInfo{ + "checksum1": { + Reason: "reasontest", + Operator: "operatortest", + CreatedTimeNano: func() *int64 { var ct int64 = 12345; return &ct }(), + }, + }, + }, + }, + }, + wantErr: false, + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + GetMetadata(gomock.Any()). + Return(&persistence.GetMetadataResponse{NotificationVersion: 1}, nil).AnyTimes() + + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(&persistence.GetDomainResponse{ + Info: &persistence.DomainInfo{ID: "testDomainID", Name: "testDomain"}, + Config: &persistence.DomainConfig{}, + ReplicationConfig: &persistence.DomainReplicationConfig{}, + }, nil). + AnyTimes() + + mockDomainManager.EXPECT(). + UpdateDomain(gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + + }, + }, + { + name: "Update Domain with BadBinaries Set", + task: &types.DomainTaskAttributes{ + Info: &types.DomainInfo{ + Name: "existingDomainName", + }, + Config: &types.DomainConfiguration{ + BadBinaries: &types.BadBinaries{ + Binaries: map[string]*types.BadBinaryInfo{}, + }, + }, + }, + wantErr: true, + setup: func(mockDomainManager *persistence.MockDomainManager) { + mockDomainManager.EXPECT(). + GetMetadata(gomock.Any()). + Return(&persistence.GetMetadataResponse{NotificationVersion: 1}, nil). + AnyTimes() + + mockDomainManager.EXPECT(). + GetDomain(gomock.Any(), gomock.Any()). + Return(&persistence.GetDomainResponse{}, nil). + AnyTimes() + + mockDomainManager.EXPECT(). + UpdateDomain(gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + }, + }, + } + assert := assert.New(t) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + + mockDomainManager := persistence.NewMockDomainManager(mockCtrl) + mockLogger := testlogger.New(t) + mockTimeSource := clock.NewMockedTimeSourceAt(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)) // Fixed time + + executor := &domainReplicationTaskExecutorImpl{ + domainManager: mockDomainManager, + logger: mockLogger, + timeSource: mockTimeSource, + } + tt.setup(mockDomainManager) + + err := executor.handleDomainUpdateReplicationTask(context.Background(), tt.task) + if tt.wantErr { + assert.Error(err, "Expected an error for test case: %s", tt.name) + } else { + assert.NoError(err, "Expected no error for test case: %s", tt.name) + } + + }) + } +} diff --git a/common/elasticsearch/client/os2/client_test.go b/common/elasticsearch/client/os2/client_test.go index a4e643495ce..e579bf36602 100644 --- a/common/elasticsearch/client/os2/client_test.go +++ b/common/elasticsearch/client/os2/client_test.go @@ -26,6 +26,8 @@ import ( "bytes" "context" "crypto/tls" + "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -40,6 +42,13 @@ import ( "github.com/uber/cadence/common/log/testlogger" ) +type MockTransport struct{} + +func (m *MockTransport) Perform(req *http.Request) (*http.Response, error) { + // Simulate a network or connection error + return nil, fmt.Errorf("forced connection error") +} + func TestNewClient(t *testing.T) { logger := testlogger.New(t) testServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -288,3 +297,314 @@ func TestCloseBody(t *testing.T) { _, err = osResponse.Body.Read(make([]byte, 1)) assert.Error(t, err, "Expected response body to be closed after calling closeBody") } + +func TestPutMapping(t *testing.T) { + testCases := []struct { + name string + handler http.HandlerFunc + index string + body string + expectedErr bool + }{ + { + name: "Successful PutMapping", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, + index: "testIndex", + body: `{"properties": {"field": {"type": "text"}}}`, + expectedErr: false, + }, + { + name: "Failed PutMapping", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + index: "nonExistentIndex", + body: `{"properties": {"field": {"type": "text"}}}`, + expectedErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + err := os2Client.PutMapping(context.Background(), tc.index, tc.body) + + if tc.expectedErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPutMappingError(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + os2Client, testServer := getSecureMockOS2Client(t, http.HandlerFunc(handler), true) + defer testServer.Close() + os2Client.client.Transport = &MockTransport{} + err := os2Client.PutMapping(context.Background(), "testIndex", `{"properties": {"field": {"type": "text"}}}`) + assert.Error(t, err) +} + +func TestIsNotFoundError(t *testing.T) { + testCases := []struct { + name string + handler http.HandlerFunc + expected bool + }{ + { + name: "NotFound error", + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]interface{}{}, + "status": 404, + }) + }), + expected: true, + }, + { + name: "Other error", + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Bad Request", http.StatusBadRequest) + }), + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + err := os2Client.CreateIndex(context.Background(), "testIndex") + res := os2Client.IsNotFoundError(err) + assert.Equal(t, tc.expected, res) + }) + } +} + +func TestCount(t *testing.T) { + testCases := []struct { + name string + handler http.HandlerFunc + index string + query string + expectedCount int64 + expectError bool + }{ + { + name: "Successful Count", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{"count": 42}`) + }, + index: "testIndex", + query: "{}", + expectedCount: 42, + expectError: false, + }, + { + name: "OpenSearch Error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, `{"error": "Internal Server Error"}`) + }, + index: "testIndex", + query: "{}", + expectError: true, + }, + { + name: "Decoding Error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{"count": "should be an int64"}`) + }, + index: "testIndex", + query: "{}", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + count, err := os2Client.Count(context.Background(), tc.index, tc.query) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedCount, count) + } + }) + } +} + +func TestScroll(t *testing.T) { + testCases := []struct { + name string + scrollID string + handler http.HandlerFunc + expectError bool + expectedScrollID string // Add more fields as needed for assertions + }{ + { + name: "Initial Search Request", + scrollID: "", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"_scroll_id": "scrollID123", "took": 10, "hits": {"total": {"value": 2}, "hits": [{"_source": {"field1": "value1"}}]}}`) + }, + expectError: false, + expectedScrollID: "scrollID123", + }, + { + name: "Subsequent Scroll Request", + scrollID: "existingScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"_scroll_id": "scrollID456", "took": 5, "hits": {"total": {"value": 1}, "hits": [{"_source": {"field2": "value2"}}]}}`) + }, + expectError: false, + expectedScrollID: "scrollID456", + }, + { + name: "Error Response", + scrollID: "", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, `{"error": "Internal Server Error"}`) + }, + expectError: true, + }, + { + name: "No More Hits", + scrollID: "someScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"_scroll_id": "scrollIDNoHits", "took": 5, "hits": {"hits": []}}`) + }, + expectError: false, + expectedScrollID: "scrollIDNoHits", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + resp, err := os2Client.Scroll(context.Background(), "testIndex", "{}", tc.scrollID) + + if tc.expectError { + assert.Error(t, err) + } else if tc.name == "No More Hits" { + assert.Equal(t, io.EOF, err, "Expected io.EOF error for no more hits") + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tc.expectedScrollID, resp.ScrollID) + } + }) + } +} + +func TestClearScroll(t *testing.T) { + testCases := []struct { + name string + scrollID string + handler http.HandlerFunc + expectedError bool + }{ + { + name: "Successful Scroll Clear", + scrollID: "testScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{}`) + }, + expectedError: false, + }, + { + name: "OpenSearch Server Error", + scrollID: "testScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, `{"error": {"root_cause": [{"type": "internal_server_error","reason": "Internal server error"}],"type": "internal_server_error","reason": "Internal server error"}}`) + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + err := os2Client.ClearScroll(context.Background(), tc.scrollID) + + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestSearch(t *testing.T) { + testCases := []struct { + name string + index string + body string + handler http.HandlerFunc + expectedError bool + expectedHits int + }{ + { + name: "Successful Search", + index: "testIndex", + body: `{"query": {"match_all": {}}}`, + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"took": 10, "hits": {"total": {"value": 2}, "hits": [{"_source": {"field": "value"}}, {"_source": {"field": "another value"}}]}}`) + }, + expectedError: false, + expectedHits: 2, + }, + { + name: "OpenSearch Error", + index: "testIndex", + body: `{"query": {"match_all": {}}}`, + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, `{"error": "Bad request"}`) + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + resp, err := os2Client.Search(context.Background(), tc.index, tc.body) + + if tc.expectedError { + assert.Error(t, err) + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Len(t, resp.Hits.Hits, tc.expectedHits) + } + }) + } +} diff --git a/common/persistence/nosql/nosql_execution_store.go b/common/persistence/nosql/nosql_execution_store.go index 4083418102b..96311880768 100644 --- a/common/persistence/nosql/nosql_execution_store.go +++ b/common/persistence/nosql/nosql_execution_store.go @@ -718,7 +718,6 @@ func (d *nosqlExecutionStore) PutReplicationTaskToDLQ( ctx context.Context, request *persistence.InternalPutReplicationTaskToDLQRequest, ) error { - err := d.db.InsertReplicationDLQTask(ctx, d.shardID, request.SourceClusterName, *request.TaskInfo) if err != nil { return convertCommonErrors(d.db, "PutReplicationTaskToDLQ", err) @@ -731,6 +730,9 @@ func (d *nosqlExecutionStore) GetReplicationTasksFromDLQ( ctx context.Context, request *persistence.GetReplicationTasksFromDLQRequest, ) (*persistence.InternalGetReplicationTasksFromDLQResponse, error) { + if request.ReadLevel > request.MaxReadLevel { + return nil, &types.BadRequestError{Message: "ReadLevel cannot be higher than MaxReadLevel"} + } tasks, nextPageToken, err := d.db.SelectReplicationDLQTasksOrderByTaskID(ctx, d.shardID, request.SourceClusterName, request.BatchSize, request.NextPageToken, request.ReadLevel, request.MaxReadLevel) if err != nil { return nil, convertCommonErrors(d.db, "GetReplicationTasksFromDLQ", err) diff --git a/common/persistence/nosql/nosql_execution_store_test.go b/common/persistence/nosql/nosql_execution_store_test.go index 2913b4ce3b1..60e52035909 100644 --- a/common/persistence/nosql/nosql_execution_store_test.go +++ b/common/persistence/nosql/nosql_execution_store_test.go @@ -24,6 +24,7 @@ import ( "context" "errors" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" @@ -688,11 +689,369 @@ func TestNosqlExecutionStore(t *testing.T) { }, expectedError: &types.EntityNotExistsError{Message: "replication task does not exist"}, }, + { + name: "RangeCompleteReplicationTask success", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + RangeDeleteReplicationTasks(ctx, shardID, int64(10)). + Return(nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.RangeCompleteReplicationTask(ctx, &persistence.RangeCompleteReplicationTaskRequest{ + InclusiveEndTaskID: 10, + }) + return err + }, + expectedError: nil, + }, + { + name: "RangeCompleteReplicationTask failure - database error", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + RangeDeleteReplicationTasks(ctx, shardID, int64(10)). + Return(errors.New("database error")) + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.RangeCompleteReplicationTask(ctx, &persistence.RangeCompleteReplicationTaskRequest{ + InclusiveEndTaskID: 10, + }) + return err + }, + expectedError: &types.InternalServiceError{Message: "database error"}, + }, + { + name: "RangeCompleteReplicationTask with zero InclusiveEndTaskID", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + // Expect the call with InclusiveEndTaskID of 0 + mockDB.EXPECT(). + RangeDeleteReplicationTasks(ctx, shardID, int64(0)). + Return(nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.RangeCompleteReplicationTask(ctx, &persistence.RangeCompleteReplicationTaskRequest{ + InclusiveEndTaskID: 0, + }) + return err + }, + expectedError: nil, + }, + + { + name: "CompleteTimerTask success", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + DeleteTimerTask(ctx, shardID, int64(1), gomock.Any()). + Return(nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + return store.CompleteTimerTask(ctx, &persistence.CompleteTimerTaskRequest{ + TaskID: 1, + VisibilityTimestamp: time.Now(), + }) + }, + expectedError: nil, + }, + { + name: "CompleteTimerTask failure - database error", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + DeleteTimerTask(ctx, shardID, int64(1), gomock.Any()). + Return(errors.New("database error")) + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + return store.CompleteTimerTask(ctx, &persistence.CompleteTimerTaskRequest{ + TaskID: 1, + VisibilityTimestamp: time.Now(), + }) + }, + expectedError: &types.InternalServiceError{Message: "database error"}, + }, + { + name: "CompleteTimerTask with future VisibilityTimestamp", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + DeleteTimerTask(ctx, shardID, int64(1), gomock.Any()). // Expect the call with any timestamp + Return(nil) // Assuming successful deletion + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + return store.CompleteTimerTask(ctx, &persistence.CompleteTimerTaskRequest{ + TaskID: 1, + VisibilityTimestamp: time.Now().Add(24 * time.Hour), // Future timestamp + }) + }, + expectedError: nil, // Adjust based on actual behavior + }, + { + name: "RangeCompleteTimerTask success", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + RangeDeleteTimerTasks(ctx, shardID, gomock.Any(), gomock.Any()). + Return(nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + now := time.Now() + // Assuming you're testing with a time range starting from 'now' and ending 1 hour later. + beginTime := now + endTime := now.Add(time.Hour) + + _, err := store.RangeCompleteTimerTask(ctx, &persistence.RangeCompleteTimerTaskRequest{ + InclusiveBeginTimestamp: beginTime, + ExclusiveEndTimestamp: endTime, + }) + return err + }, + expectedError: nil, + }, + { + name: "RangeCompleteTimerTask failure - database error", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + RangeDeleteTimerTasks(ctx, shardID, gomock.Any(), gomock.Any()). + Return(errors.New("database error")) + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + now := time.Now() + // Assuming you're testing with a time range starting from 'now' and ending 1 hour later. + beginTime := now + endTime := now.Add(time.Hour) + _, err := store.RangeCompleteTimerTask(ctx, &persistence.RangeCompleteTimerTaskRequest{ + InclusiveBeginTimestamp: beginTime, + ExclusiveEndTimestamp: endTime, + }) + return err + }, + expectedError: &types.InternalServiceError{Message: "database error"}, + }, + { + name: "RangeCompleteTimerTask with inverted time range proceeds", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + // Set up an expectation for the call, even with inverted time range + mockDB.EXPECT(). + RangeDeleteTimerTasks(ctx, shardID, gomock.Any(), gomock.Any()). + Return(nil) // Assuming the operation proceeds regardless of time range order + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.RangeCompleteTimerTask(ctx, &persistence.RangeCompleteTimerTaskRequest{ + InclusiveBeginTimestamp: time.Now().Add(time.Hour), // Future time + ExclusiveEndTimestamp: time.Now(), // Present time + }) + return err + }, + expectedError: nil, + }, + { + name: "GetTimerIndexTasks success", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + SelectTimerTasksOrderByVisibilityTime( + ctx, + shardID, + 10, + gomock.Nil(), + gomock.Any(), + gomock.Any(), + ).Return([]*persistence.TimerTaskInfo{}, nil, nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.GetTimerIndexTasks(ctx, &persistence.GetTimerIndexTasksRequest{ + BatchSize: 10, + MinTimestamp: time.Now().Add(-time.Hour), + MaxTimestamp: time.Now(), + }) + return err + }, + expectedError: nil, + }, + { + name: "GetTimerIndexTasks success - empty result", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + SelectTimerTasksOrderByVisibilityTime(ctx, shardID, 10, gomock.Nil(), gomock.Any(), gomock.Any()). + Return([]*persistence.TimerTaskInfo{}, []byte{}, nil) // Return an empty list + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + resp, err := store.GetTimerIndexTasks(ctx, &persistence.GetTimerIndexTasksRequest{ + BatchSize: 10, + MinTimestamp: time.Now().Add(-time.Hour), + MaxTimestamp: time.Now(), + }) + if err != nil { + return err + } + if len(resp.Timers) != 0 { + return errors.New("expected empty result set for timers") + } + return nil + }, + expectedError: nil, + }, + { + name: "PutReplicationTaskToDLQ success", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + replicationTaskInfo := newInternalReplicationTaskInfo() + + mockDB.EXPECT(). + InsertReplicationDLQTask(ctx, shardID, "sourceCluster", gomock.Any()). + DoAndReturn(func(_ context.Context, _ int, _ string, taskInfo persistence.InternalReplicationTaskInfo) error { + require.Equal(t, replicationTaskInfo, taskInfo) + return nil + }) + + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + taskInfo := newInternalReplicationTaskInfo() + return store.PutReplicationTaskToDLQ(ctx, &persistence.InternalPutReplicationTaskToDLQRequest{ + SourceClusterName: "sourceCluster", + TaskInfo: &taskInfo, + }) + }, + expectedError: nil, + }, + { + name: "GetTimerIndexTasks failure - database error", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + mockDB.EXPECT(). + SelectTimerTasksOrderByVisibilityTime(ctx, shardID, 10, gomock.Nil(), gomock.Any(), gomock.Any()). + Return(nil, nil, errors.New("database error")) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.GetTimerIndexTasks(ctx, &persistence.GetTimerIndexTasksRequest{ + BatchSize: 10, + MinTimestamp: time.Now().Add(-time.Hour), + MaxTimestamp: time.Now(), + }) + return err + }, + expectedError: &types.InternalServiceError{Message: "database error"}, + }, + { + name: "GetReplicationTasksFromDLQ success", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + + nextPageToken := []byte("next-page-token") + replicationTasks := []*persistence.InternalReplicationTaskInfo{} + mockDB.EXPECT(). + SelectReplicationDLQTasksOrderByTaskID( + ctx, + shardID, + "sourceCluster", + 10, + gomock.Any(), + int64(0), + int64(100), + ).Return(replicationTasks, nextPageToken, nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + initialNextPageToken := []byte{} + _, err := store.GetReplicationTasksFromDLQ(ctx, &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: "sourceCluster", + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + BatchSize: 10, + NextPageToken: initialNextPageToken, + ReadLevel: 0, + MaxReadLevel: 100, + }, + }) + + return err + }, + expectedError: nil, + }, + { + name: "GetReplicationTasksFromDLQ failure - invalid read levels", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + return newTestNosqlExecutionStore(nosqlplugin.NewMockDB(ctrl), log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.GetReplicationTasksFromDLQ(ctx, &persistence.GetReplicationTasksFromDLQRequest{ + SourceClusterName: "sourceCluster", + GetReplicationTasksRequest: persistence.GetReplicationTasksRequest{ + ReadLevel: 100, + MaxReadLevel: 50, + BatchSize: 10, + NextPageToken: []byte{}, + }, + }) + return err + }, + expectedError: &types.BadRequestError{Message: "ReadLevel cannot be higher than MaxReadLevel"}, + }, + { + name: "GetReplicationDLQSize success", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + SelectReplicationDLQTasksCount(ctx, shardID, "sourceCluster"). + Return(int64(42), nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + resp, err := store.GetReplicationDLQSize(ctx, &persistence.GetReplicationDLQSizeRequest{ + SourceClusterName: "sourceCluster", + }) + if err != nil { + return err + } + if resp.Size != 42 { + return errors.New("unexpected DLQ size") + } + return nil + }, + expectedError: nil, + }, + { + name: "GetReplicationDLQSize failure - invalid source cluster name", + setupMock: func(ctrl *gomock.Controller) *nosqlExecutionStore { + mockDB := nosqlplugin.NewMockDB(ctrl) + mockDB.EXPECT(). + SelectReplicationDLQTasksCount(ctx, shardID, ""). + Return(int64(0), nil) + return newTestNosqlExecutionStore(mockDB, log.NewNoop()) + }, + testFunc: func(store *nosqlExecutionStore) error { + _, err := store.GetReplicationDLQSize(ctx, &persistence.GetReplicationDLQSizeRequest{ + SourceClusterName: "", + }) + return err + }, + expectedError: nil, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ctrl := gomock.NewController(t) - defer ctrl.Finish() store := tc.setupMock(ctrl) err := tc.testFunc(store) @@ -707,6 +1066,179 @@ func TestNosqlExecutionStore(t *testing.T) { } } +func TestDeleteReplicationTaskFromDLQ(t *testing.T) { + ctx := context.Background() + shardID := 1 + + tests := []struct { + name string + sourceCluster string + taskID int64 + setupMock func(*nosqlplugin.MockDB) + expectedError error + }{ + { + name: "success", + sourceCluster: "sourceCluster", + taskID: 1, + setupMock: func(mockDB *nosqlplugin.MockDB) { + mockDB.EXPECT(). + DeleteReplicationDLQTask(ctx, shardID, "sourceCluster", int64(1)). + Return(nil) + }, + expectedError: nil, + }, + { + name: "database error", + sourceCluster: "sourceCluster", + taskID: 1, + setupMock: func(mockDB *nosqlplugin.MockDB) { + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + mockDB.EXPECT(). + DeleteReplicationDLQTask(ctx, shardID, "sourceCluster", int64(1)). + Return(errors.New("database error")) + }, + expectedError: &types.InternalServiceError{Message: "database error"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + controller := gomock.NewController(t) + + mockDB := nosqlplugin.NewMockDB(controller) + store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) + + tc.setupMock(mockDB) + + err := store.DeleteReplicationTaskFromDLQ(ctx, &persistence.DeleteReplicationTaskFromDLQRequest{ + SourceClusterName: tc.sourceCluster, + TaskID: tc.taskID, + }) + + if tc.expectedError != nil { + require.ErrorAs(t, err, &tc.expectedError) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestRangeDeleteReplicationTaskFromDLQ(t *testing.T) { + ctx := context.Background() + shardID := 1 + + tests := []struct { + name string + sourceCluster string + exclusiveBeginID int64 + inclusiveEndID int64 + setupMock func(*nosqlplugin.MockDB) + expectedError error + }{ + { + name: "success", + sourceCluster: "sourceCluster", + exclusiveBeginID: 1, + inclusiveEndID: 100, + setupMock: func(mockDB *nosqlplugin.MockDB) { + mockDB.EXPECT(). + RangeDeleteReplicationDLQTasks(ctx, shardID, "sourceCluster", int64(1), int64(100)). + Return(nil) + }, + expectedError: nil, + }, + { + name: "database error", + sourceCluster: "sourceCluster", + exclusiveBeginID: 1, + inclusiveEndID: 100, + setupMock: func(mockDB *nosqlplugin.MockDB) { + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + mockDB.EXPECT(). + RangeDeleteReplicationDLQTasks(ctx, shardID, "sourceCluster", int64(1), int64(100)). + Return(errors.New("database error")) + }, + expectedError: &types.InternalServiceError{Message: "database error"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + controller := gomock.NewController(t) + + mockDB := nosqlplugin.NewMockDB(controller) + store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) + + tc.setupMock(mockDB) + + _, err := store.RangeDeleteReplicationTaskFromDLQ(ctx, &persistence.RangeDeleteReplicationTaskFromDLQRequest{ + SourceClusterName: tc.sourceCluster, + ExclusiveBeginTaskID: tc.exclusiveBeginID, + InclusiveEndTaskID: tc.inclusiveEndID, + }) + + if tc.expectedError != nil { + require.ErrorAs(t, err, &tc.expectedError) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCreateFailoverMarkerTasks(t *testing.T) { + ctx := context.Background() + shardID := 1 + + tests := []struct { + name string + rangeID int64 + markers []*persistence.FailoverMarkerTask + setupMock func(*nosqlplugin.MockDB) + expectedError error + }{ + { + name: "success", + rangeID: 123, + markers: []*persistence.FailoverMarkerTask{ + { + TaskData: persistence.TaskData{}, + DomainID: "testDomainID", + }, + }, + setupMock: func(mockDB *nosqlplugin.MockDB) { + mockDB.EXPECT(). + InsertReplicationTask(ctx, gomock.Any(), nosqlplugin.ShardCondition{ShardID: shardID, RangeID: 123}). + Return(nil) + }, + expectedError: nil, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + controller := gomock.NewController(t) + + mockDB := nosqlplugin.NewMockDB(controller) + store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) + + tc.setupMock(mockDB) + + err := store.CreateFailoverMarkerTasks(ctx, &persistence.CreateFailoverMarkersRequest{ + RangeID: tc.rangeID, + Markers: tc.markers, + }) + + if tc.expectedError != nil { + require.ErrorAs(t, err, &tc.expectedError) + } else { + require.NoError(t, err) + } + }) + } +} + func newCreateWorkflowExecutionRequest() *persistence.InternalCreateWorkflowExecutionRequest { return &persistence.InternalCreateWorkflowExecutionRequest{ RangeID: 123, @@ -758,3 +1290,21 @@ func newTestNosqlExecutionStore(db nosqlplugin.DB, logger log.Logger) *nosqlExec nosqlStore: nosqlStore{logger: logger, db: db}, } } + +func newInternalReplicationTaskInfo() persistence.InternalReplicationTaskInfo { + var fixedCreationTime = time.Date(2024, time.April, 3, 14, 35, 44, 0, time.UTC) + return persistence.InternalReplicationTaskInfo{ + DomainID: "testDomainID", + WorkflowID: "testWorkflowID", + RunID: "testRunID", + TaskID: 123, + TaskType: persistence.ReplicationTaskTypeHistory, + FirstEventID: 1, + NextEventID: 2, + Version: 1, + ScheduledID: 3, + BranchToken: []byte("branchToken"), + NewRunBranchToken: []byte("newRunBranchToken"), + CreationTime: fixedCreationTime, + } +} diff --git a/common/persistence/nosql/nosql_execution_store_util_test.go b/common/persistence/nosql/nosql_execution_store_util_test.go index ce3c627ce31..59180fd2095 100644 --- a/common/persistence/nosql/nosql_execution_store_util_test.go +++ b/common/persistence/nosql/nosql_execution_store_util_test.go @@ -23,6 +23,7 @@ package nosql import ( + "context" "errors" "testing" "time" @@ -35,6 +36,7 @@ import ( "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" + "github.com/uber/cadence/common/types" ) func TestNosqlExecutionStoreUtils(t *testing.T) { @@ -203,7 +205,6 @@ func TestNosqlExecutionStoreUtils(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockDB := nosqlplugin.NewMockDB(mockCtrl) store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) @@ -253,12 +254,109 @@ func TestPrepareTasksForWorkflowTxn(t *testing.T) { assert.Nil(t, tasks) }, }, + { + name: "PrepareTimerTasksForWorkflowTxn - Zero Tasks", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.TimerTask, error) { + return store.prepareTimerTasksForWorkflowTxn("domainID", "workflowID", "runID", []persistence.Task{}) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.TimerTask, err error) { + assert.NoError(t, err) + assert.Empty(t, tasks) + }, + }, + { + name: "PrepareTimerTasksForWorkflowTxn - ActivityTimeoutTask", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.TimerTask, error) { + timerTasks := []persistence.Task{ + &persistence.ActivityTimeoutTask{ + TaskData: persistence.TaskData{ + Version: 1, + TaskID: 2, + VisibilityTimestamp: time.Now(), + }, + EventID: 3, + Attempt: 2, + }, + } + return store.prepareTimerTasksForWorkflowTxn("domainID", "workflowID", "runID", timerTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.TimerTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, int64(3), tasks[0].EventID) + assert.Equal(t, int64(2), tasks[0].ScheduleAttempt) + }, + }, + { + name: "PrepareTimerTasksForWorkflowTxn - UserTimerTask", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.TimerTask, error) { + timerTasks := []persistence.Task{ + &persistence.UserTimerTask{ + TaskData: persistence.TaskData{ + Version: 1, + TaskID: 3, + VisibilityTimestamp: time.Now(), + }, + EventID: 4, + }, + } + return store.prepareTimerTasksForWorkflowTxn("domainID", "workflowID", "runID", timerTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.TimerTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, int64(4), tasks[0].EventID) + }, + }, + { + name: "PrepareTimerTasksForWorkflowTxn - ActivityRetryTimerTask", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.TimerTask, error) { + timerTasks := []persistence.Task{ + &persistence.ActivityRetryTimerTask{ + TaskData: persistence.TaskData{ + Version: 1, + TaskID: 4, + VisibilityTimestamp: time.Now(), + }, + EventID: 5, + Attempt: 3, + }, + } + return store.prepareTimerTasksForWorkflowTxn("domainID", "workflowID", "runID", timerTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.TimerTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, int64(5), tasks[0].EventID) + assert.Equal(t, int64(3), tasks[0].ScheduleAttempt) + }, + }, + { + name: "PrepareTimerTasksForWorkflowTxn - WorkflowBackoffTimerTask", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.TimerTask, error) { + timerTasks := []persistence.Task{ + &persistence.WorkflowBackoffTimerTask{ + TaskData: persistence.TaskData{ + Version: 1, + TaskID: 5, + VisibilityTimestamp: time.Now(), + }, + EventID: 6, + }, + } + return store.prepareTimerTasksForWorkflowTxn("domainID", "workflowID", "runID", timerTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.TimerTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + assert.Equal(t, int64(6), tasks[0].EventID) + }, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockDB := nosqlplugin.NewMockDB(mockCtrl) store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) @@ -271,7 +369,6 @@ func TestPrepareTasksForWorkflowTxn(t *testing.T) { func TestPrepareReplicationTasksForWorkflowTxn(t *testing.T) { mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockDB := nosqlplugin.NewMockDB(mockCtrl) store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) @@ -314,6 +411,52 @@ func TestPrepareReplicationTasksForWorkflowTxn(t *testing.T) { assert.Nil(t, tasks) }, }, + { + name: "PrepareReplicationTasksForWorkflowTxn - SyncActivityTask", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.ReplicationTask, error) { + replicationTasks := []persistence.Task{ + &persistence.SyncActivityTask{ + TaskData: persistence.TaskData{ + Version: 2, + VisibilityTimestamp: time.Now(), + TaskID: 2, + }, + ScheduledID: 123, + }, + } + return store.prepareReplicationTasksForWorkflowTxn("domainID", "workflowID", "runID", replicationTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.ReplicationTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + task := tasks[0] + assert.Equal(t, persistence.ReplicationTaskTypeSyncActivity, task.TaskType) + assert.Equal(t, int64(123), task.ScheduledID) + }, + }, + { + name: "PrepareReplicationTasksForWorkflowTxn - FailoverMarkerTask", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.ReplicationTask, error) { + replicationTasks := []persistence.Task{ + &persistence.FailoverMarkerTask{ + TaskData: persistence.TaskData{ + Version: 3, + VisibilityTimestamp: time.Now(), + TaskID: 3, + }, + DomainID: "domainID", + }, + } + return store.prepareReplicationTasksForWorkflowTxn("domainID", "workflowID", "runID", replicationTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.ReplicationTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + task := tasks[0] + assert.Equal(t, persistence.ReplicationTaskTypeFailoverMarker, task.TaskType) + assert.Equal(t, "domainID", task.DomainID) + }, + }, } for _, tc := range testCases { @@ -326,7 +469,6 @@ func TestPrepareReplicationTasksForWorkflowTxn(t *testing.T) { func TestPrepareCrossClusterTasksForWorkflowTxn(t *testing.T) { mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockDB := nosqlplugin.NewMockDB(mockCtrl) store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) @@ -367,6 +509,112 @@ func TestPrepareCrossClusterTasksForWorkflowTxn(t *testing.T) { assert.Nil(t, tasks) }, }, + { + name: "CrossClusterCancelExecutionTask - Success", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.CrossClusterTask, error) { + crossClusterTasks := []persistence.Task{ + &persistence.CrossClusterCancelExecutionTask{ + CancelExecutionTask: persistence.CancelExecutionTask{ + TaskData: persistence.TaskData{ + TaskID: 1001, + }, + TargetDomainID: "targetDomainID-cancel", + TargetWorkflowID: "targetWorkflowID-cancel", + TargetRunID: "targetRunID-cancel", + TargetChildWorkflowOnly: true, + InitiatedID: 1001, + }, + TargetCluster: "targetCluster-cancel", + }, + } + return store.prepareCrossClusterTasksForWorkflowTxn("domainID", "workflowID", "runID", crossClusterTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.CrossClusterTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + task := tasks[0] + assert.Equal(t, "targetCluster-cancel", task.TargetCluster) + assert.Equal(t, int64(1001), task.TransferTask.ScheduleID) + }, + }, + { + name: "CrossClusterSignalExecutionTask - Success", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.CrossClusterTask, error) { + crossClusterTasks := []persistence.Task{ + &persistence.CrossClusterSignalExecutionTask{ + SignalExecutionTask: persistence.SignalExecutionTask{ + TaskData: persistence.TaskData{ + TaskID: 1002, + }, + TargetDomainID: "targetDomainID-signal", + TargetWorkflowID: "targetWorkflowID-signal", + TargetRunID: "targetRunID-signal", + TargetChildWorkflowOnly: true, + InitiatedID: 1002, + }, + TargetCluster: "targetCluster-signal", + }, + } + return store.prepareCrossClusterTasksForWorkflowTxn("domainID", "workflowID", "runID", crossClusterTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.CrossClusterTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + task := tasks[0] + assert.Equal(t, "targetCluster-signal", task.TargetCluster) + assert.Equal(t, int64(1002), task.TransferTask.ScheduleID) + }, + }, + { + name: "CrossClusterRecordChildExecutionCompletedTask - Success", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.CrossClusterTask, error) { + crossClusterTasks := []persistence.Task{ + &persistence.CrossClusterRecordChildExecutionCompletedTask{ + RecordChildExecutionCompletedTask: persistence.RecordChildExecutionCompletedTask{ + TaskData: persistence.TaskData{ + TaskID: 1003, + }, + TargetDomainID: "targetDomainID-record", + TargetWorkflowID: "targetWorkflowID-record", + TargetRunID: "targetRunID-record", + }, + TargetCluster: "targetCluster-record", + }, + } + return store.prepareCrossClusterTasksForWorkflowTxn("domainID", "workflowID", "runID", crossClusterTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.CrossClusterTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + task := tasks[0] + assert.Equal(t, "targetCluster-record", task.TargetCluster) + }, + }, + { + name: "CrossClusterApplyParentClosePolicyTask - Success", + setupStore: func(store *nosqlExecutionStore) ([]*nosqlplugin.CrossClusterTask, error) { + crossClusterTasks := []persistence.Task{ + &persistence.CrossClusterApplyParentClosePolicyTask{ + ApplyParentClosePolicyTask: persistence.ApplyParentClosePolicyTask{ + TaskData: persistence.TaskData{ + TaskID: 1004, + }, + TargetDomainIDs: map[string]struct{}{"targetDomainID-apply-close": {}}, + }, + TargetCluster: "targetCluster-apply-close", + }, + } + return store.prepareCrossClusterTasksForWorkflowTxn("domainID", "workflowID", "runID", crossClusterTasks) + }, + validate: func(t *testing.T, tasks []*nosqlplugin.CrossClusterTask, err error) { + assert.NoError(t, err) + assert.Len(t, tasks, 1) + task := tasks[0] + assert.Equal(t, "targetCluster-apply-close", task.TargetCluster) + _, exists := task.TransferTask.TargetDomainIDs["targetDomainID-apply-close"] + assert.True(t, exists) + }, + }, } for _, tc := range testCases { @@ -379,7 +627,6 @@ func TestPrepareCrossClusterTasksForWorkflowTxn(t *testing.T) { func TestPrepareNoSQLTasksForWorkflowTxn(t *testing.T) { mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockDB := nosqlplugin.NewMockDB(mockCtrl) store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) @@ -463,17 +710,434 @@ func TestPrepareTransferTasksForWorkflowTxn(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() mockDB := nosqlplugin.NewMockDB(mockCtrl) store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) if tc.expectFunc != nil { - tc.expectFunc(mockDB) // Set up any expectations on the mockDB + tc.expectFunc(mockDB) } tasks, err := store.prepareTransferTasksForWorkflowTxn("domainID", "workflowID", "runID", tc.tasks) - tc.validate(t, tasks, err) // Validate the output + tc.validate(t, tasks, err) + }) + } +} + +func TestNosqlExecutionStoreUtilsExtended(t *testing.T) { + testCases := []struct { + name string + setupStore func(store *nosqlExecutionStore) (interface{}, error) + validate func(t *testing.T, result interface{}, err error) + }{ + { + name: "PrepareActivityInfosForWorkflowTxn - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + activityInfos := []*persistence.InternalActivityInfo{ + { + ScheduleID: 1, + ScheduledEvent: persistence.NewDataBlob([]byte("scheduled event data"), common.EncodingTypeThriftRW), + StartedEvent: persistence.NewDataBlob([]byte("started event data"), common.EncodingTypeThriftRW), + }, + } + return store.prepareActivityInfosForWorkflowTxn(activityInfos) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + infos, ok := result.(map[int64]*persistence.InternalActivityInfo) + assert.True(t, ok) + assert.Len(t, infos, 1) + for _, info := range infos { + assert.NotNil(t, info.ScheduledEvent) + assert.NotNil(t, info.StartedEvent) + } + }, + }, + { + name: "PrepareTimerInfosForWorkflowTxn - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + timerInfos := []*persistence.TimerInfo{ + { + TimerID: "timer1", + }, + } + return store.prepareTimerInfosForWorkflowTxn(timerInfos) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + infos, ok := result.(map[string]*persistence.TimerInfo) + assert.True(t, ok) + assert.Len(t, infos, 1) + assert.NotNil(t, infos["timer1"]) + }, + }, + { + name: "PrepareChildWFInfosForWorkflowTxn - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + childWFInfos := []*persistence.InternalChildExecutionInfo{ + { + InitiatedID: 1, + InitiatedEvent: persistence.NewDataBlob([]byte("initiated event data"), common.EncodingTypeThriftRW), + StartedEvent: persistence.NewDataBlob([]byte("started event data"), common.EncodingTypeThriftRW), + }, + } + return store.prepareChildWFInfosForWorkflowTxn(childWFInfos) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + infos, ok := result.(map[int64]*persistence.InternalChildExecutionInfo) + assert.True(t, ok) + assert.Len(t, infos, 1) + for _, info := range infos { + assert.NotNil(t, info.InitiatedEvent) + assert.NotNil(t, info.StartedEvent) + } + }, + }, + { + name: "PrepareTimerInfosForWorkflowTxn - Nil Timer Info", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + return store.prepareTimerInfosForWorkflowTxn(nil) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + assert.Empty(t, result) + }, + }, + { + name: "PrepareChildWFInfosForWorkflowTxn - Nil Child Execution Info", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + return store.prepareChildWFInfosForWorkflowTxn(nil) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + assert.Empty(t, result) + }, + }, + { + name: "PrepareChildWFInfosForWorkflowTxn - Encoding Mismatch Error", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + childWFInfos := []*persistence.InternalChildExecutionInfo{ + { + InitiatedID: 1, + InitiatedEvent: persistence.NewDataBlob([]byte("initiated"), common.EncodingTypeThriftRW), + StartedEvent: persistence.NewDataBlob([]byte("started"), common.EncodingTypeJSON), // Encoding mismatch + }, + } + return store.prepareChildWFInfosForWorkflowTxn(childWFInfos) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.Error(t, err) + assert.Nil(t, result) + }, + }, + { + name: "PrepareRequestCancelsForWorkflowTxn - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + requestCancels := []*persistence.RequestCancelInfo{ + { + InitiatedID: 1, + CancelRequestID: "cancel-1", + }, + { + InitiatedID: 2, + CancelRequestID: "cancel-2", + }, + } + cancels, err := store.prepareRequestCancelsForWorkflowTxn(requestCancels) + return cancels, err + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + cancels := result.(map[int64]*persistence.RequestCancelInfo) + assert.Equal(t, 2, len(cancels)) + assert.Contains(t, cancels, int64(1)) + assert.Contains(t, cancels, int64(2)) + }, + }, + { + name: "PrepareRequestCancelsForWorkflowTxn - Duplicate Initiated IDs", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + requestCancels := []*persistence.RequestCancelInfo{ + { + InitiatedID: 1, + CancelRequestID: "cancel-1", + }, + { + InitiatedID: 1, // Duplicate InitiatedID + CancelRequestID: "cancel-1-duplicate", + }, + } + cancels, err := store.prepareRequestCancelsForWorkflowTxn(requestCancels) + return cancels, err + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + cancels := result.(map[int64]*persistence.RequestCancelInfo) + assert.Equal(t, 1, len(cancels)) + assert.Equal(t, "cancel-1-duplicate", cancels[1].CancelRequestID) + }, + }, + { + name: "PrepareSignalInfosForWorkflowTxn - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + signalInfos := []*persistence.SignalInfo{ + {InitiatedID: 1, SignalRequestID: "signal-1"}, + {InitiatedID: 2, SignalRequestID: "signal-2"}, + } + return store.prepareSignalInfosForWorkflowTxn(signalInfos) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + infos := result.(map[int64]*persistence.SignalInfo) + assert.Equal(t, 2, len(infos)) + assert.Equal(t, "signal-1", infos[1].SignalRequestID) + assert.Equal(t, "signal-2", infos[2].SignalRequestID) + }, + }, + { + name: "PrepareUpdateWorkflowExecutionTxn - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + executionInfo := &persistence.InternalWorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateRunning, + CloseStatus: persistence.WorkflowCloseStatusNone, + } + versionHistories := &persistence.DataBlob{ + Encoding: common.EncodingTypeJSON, + Data: []byte(`[{"Branches":[{"BranchID":"test-branch-id","BeginNodeID":1,"EndNodeID":2}]}]`), + } + checksum := checksum.Checksum{Version: 1, + Value: []byte("create-checksum")} + return store.prepareUpdateWorkflowExecutionTxn(executionInfo, versionHistories, checksum, time.Now(), 123) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + req := result.(*nosqlplugin.WorkflowExecutionRequest) + assert.Equal(t, "test-domain-id", req.DomainID) + assert.Equal(t, int64(123), req.LastWriteVersion) + }, + }, + { + name: "PrepareUpdateWorkflowExecutionTxn - Emptyvalues", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + executionInfo := &persistence.InternalWorkflowExecutionInfo{ + DomainID: "", + WorkflowID: "", + State: persistence.WorkflowStateCompleted, + } + versionHistories := &persistence.DataBlob{ + Encoding: common.EncodingTypeJSON, + Data: []byte(`[{"Branches":[{"BranchID":"branch-id","BeginNodeID":1,"EndNodeID":2}]}]`), + } + checksum := checksum.Checksum{Version: 1, Value: []byte("checksum")} + // This should result in an error due to invalid executionInfo state for the creation scenario + return store.prepareUpdateWorkflowExecutionTxn(executionInfo, versionHistories, checksum, time.Now(), 123) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.Error(t, err) // Expect an error due to invalid state + assert.Nil(t, result) + }, + }, + { + name: "PrepareUpdateWorkflowExecutionTxn - Invalid Workflow State", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + executionInfo := &persistence.InternalWorkflowExecutionInfo{ + DomainID: "domainID-invalid-state", + WorkflowID: "workflowID-invalid-state", + RunID: "runID-invalid-state", + State: 343, // Invalid state + CloseStatus: persistence.WorkflowCloseStatusNone, + } + versionHistories := &persistence.DataBlob{ + Encoding: common.EncodingTypeJSON, + Data: []byte(`[{"Branches":[{"BranchID":"branch-id","BeginNodeID":1,"EndNodeID":2}]}]`), + } + checksum := checksum.Checksum{Version: 1, Value: []byte("checksum")} + return store.prepareUpdateWorkflowExecutionTxn(executionInfo, versionHistories, checksum, time.Now(), 123) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.Error(t, err) // Expect an error due to invalid workflow state + assert.Nil(t, result) // No WorkflowExecutionRequest should be returned + }, + }, + { + name: "PrepareCreateWorkflowExecutionTxn - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + executionInfo := &persistence.InternalWorkflowExecutionInfo{ + DomainID: "create-domain-id", + WorkflowID: "create-workflow-id", + RunID: "create-run-id", + State: persistence.WorkflowStateCreated, + CloseStatus: persistence.WorkflowCloseStatusNone, + } + versionHistories := &persistence.DataBlob{ + Encoding: common.EncodingTypeJSON, + Data: []byte(`[{"Branches":[{"BranchID":"create-branch-id","BeginNodeID":1,"EndNodeID":2}]}]`), + } + checksum := checksum.Checksum{Version: 1, Value: []byte("create-checksum")} + return store.prepareCreateWorkflowExecutionTxn(executionInfo, versionHistories, checksum, time.Now(), 123) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + req := result.(*nosqlplugin.WorkflowExecutionRequest) + assert.Equal(t, "create-domain-id", req.DomainID) + assert.Equal(t, int64(123), req.LastWriteVersion) + }, + }, + { + name: "PrepareCreateWorkflowExecutionTxn - Invalid State", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + executionInfo := &persistence.InternalWorkflowExecutionInfo{ + DomainID: "create-domain-id", + WorkflowID: "create-workflow-id", + RunID: "create-run-id", + State: 232, // Invalid state for creating a workflow execution + CloseStatus: persistence.WorkflowCloseStatusNone, + } + versionHistories := &persistence.DataBlob{ + Encoding: common.EncodingTypeJSON, + Data: []byte(`[{"Branches":[{"BranchID":"create-branch-id","BeginNodeID":1,"EndNodeID":2}]}]`), + } + checksum := checksum.Checksum{Version: 1, Value: []byte("create-checksum")} + return store.prepareCreateWorkflowExecutionTxn(executionInfo, versionHistories, checksum, time.Now(), 123) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.Error(t, err) + assert.Nil(t, result) + }, + }, + { + name: "prepareCurrentWorkflowRequestForCreateWorkflowTxn - BrandNew mode", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + executionInfo := &persistence.InternalWorkflowExecutionInfo{ + State: persistence.WorkflowStateCreated, + CloseStatus: persistence.WorkflowCloseStatusNone, + CreateRequestID: "test-create-request-id", + } + request := &persistence.InternalCreateWorkflowExecutionRequest{ + Mode: persistence.CreateWorkflowModeBrandNew, + } + return store.prepareCurrentWorkflowRequestForCreateWorkflowTxn( + "test-domain-id", "test-workflow-id", "test-run-id", executionInfo, 123, request) + }, + validate: func(t *testing.T, result interface{}, err error) { + assert.NoError(t, err) + currentWorkflowReq, ok := result.(*nosqlplugin.CurrentWorkflowWriteRequest) + assert.True(t, ok) + assert.NotNil(t, currentWorkflowReq) + assert.Equal(t, nosqlplugin.CurrentWorkflowWriteModeInsert, currentWorkflowReq.WriteMode) + }, + }, + { + name: "processUpdateWorkflowResult - CurrentWorkflowConditionFailInfo error", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + err := &nosqlplugin.WorkflowOperationConditionFailure{ + CurrentWorkflowConditionFailInfo: common.StringPtr("current workflow condition failed"), + } + return nil, store.processUpdateWorkflowResult(err, 99) + }, + validate: func(t *testing.T, _ interface{}, err error) { + assert.Error(t, err) + _, ok := err.(*persistence.CurrentWorkflowConditionFailedError) + assert.True(t, ok) + }, + }, + { + name: "processUpdateWorkflowResult - Success", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + return nil, store.processUpdateWorkflowResult(nil, 99) + }, + validate: func(t *testing.T, _ interface{}, err error) { + assert.NoError(t, err) + }, + }, + { + name: "processUpdateWorkflowResult - ShardRangeIDNotMatch error", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + err := &nosqlplugin.WorkflowOperationConditionFailure{ + ShardRangeIDNotMatch: common.Int64Ptr(100), + } + return nil, store.processUpdateWorkflowResult(err, 99) + }, + validate: func(t *testing.T, _ interface{}, err error) { + assert.Error(t, err) + _, ok := err.(*persistence.ShardOwnershipLostError) + assert.True(t, ok) + }, + }, + { + name: "prepareCurrentWorkflowRequestForCreateWorkflowTxn - WorkflowIDReuse mode with non-completed state", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + executionInfo := &persistence.InternalWorkflowExecutionInfo{ + State: persistence.WorkflowStateRunning, // Simulate a running state + CloseStatus: persistence.WorkflowCloseStatusNone, + CreateRequestID: "test-create-request-id", + } + request := &persistence.InternalCreateWorkflowExecutionRequest{ + Mode: persistence.CreateWorkflowModeWorkflowIDReuse, + PreviousRunID: "test-run-id", + PreviousLastWriteVersion: 123, // Simulating a non-completed state with a valid version + } + return store.prepareCurrentWorkflowRequestForCreateWorkflowTxn( + "test-domain-id", "test-workflow-id", "test-run-id", executionInfo, 123, request) + }, + validate: func(t *testing.T, result interface{}, err error) { + _, ok := err.(*persistence.CurrentWorkflowConditionFailedError) + assert.False(t, ok) + }, + }, + { + name: "assertNotCurrentExecution - Success with different RunID", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + ctx := context.Background() + mockDB := store.db.(*nosqlplugin.MockDB) + mockDB.EXPECT().SelectCurrentWorkflow( + gomock.Any(), + store.shardID, + "test-domain-id", + "test-workflow-id", + ).Return(&nosqlplugin.CurrentWorkflowRow{ + RunID: "different-run-id", + }, nil) + return nil, store.assertNotCurrentExecution(ctx, "test-domain-id", "test-workflow-id", "expected-run-id") + }, + validate: func(t *testing.T, _ interface{}, err error) { + assert.NoError(t, err) + }, + }, + { + name: "assertNotCurrentExecution - No current workflow", + setupStore: func(store *nosqlExecutionStore) (interface{}, error) { + ctx := context.Background() + mockDB := store.db.(*nosqlplugin.MockDB) + + mockDB.EXPECT().SelectCurrentWorkflow( + gomock.Any(), + store.shardID, + "test-domain-id", + "test-workflow-id", + ).Return(nil, &types.EntityNotExistsError{}) + mockDB.EXPECT().IsNotFoundError(gomock.Any()).Return(true).AnyTimes() + return nil, store.assertNotCurrentExecution(ctx, "test-domain-id", "test-workflow-id", "expected-run-id") + }, + validate: func(t *testing.T, _ interface{}, err error) { + assert.NoError(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + + mockDB := nosqlplugin.NewMockDB(mockCtrl) + store := newTestNosqlExecutionStore(mockDB, log.NewNoop()) + + result, err := tc.setupStore(store) + tc.validate(t, result, err) }) } } diff --git a/common/persistence/nosql/nosqlplugin/cassandra/domain.go b/common/persistence/nosql/nosqlplugin/cassandra/domain.go index 220564a0715..e8170719764 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/domain.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/domain.go @@ -42,10 +42,7 @@ const ( // Insert a new record to domain // return types.DomainAlreadyExistsError error if failed or already exists // Must return ConditionFailure error if other condition doesn't match -func (db *cdb) InsertDomain( - ctx context.Context, - row *nosqlplugin.DomainRow, -) error { +func (db *cdb) InsertDomain(ctx context.Context, row *nosqlplugin.DomainRow) error { query := db.session.Query(templateCreateDomainQuery, row.Info.ID, row.Info.Name).WithContext(ctx) applied, err := query.MapScanCAS(make(map[string]interface{})) if err != nil { @@ -163,10 +160,7 @@ func (db *cdb) updateMetadataBatch( } // Update domain -func (db *cdb) UpdateDomain( - ctx context.Context, - row *nosqlplugin.DomainRow, -) error { +func (db *cdb) UpdateDomain(ctx context.Context, row *nosqlplugin.DomainRow) error { batch := db.session.NewBatch(gocql.LoggedBatch).WithContext(ctx) failoverEndTime := emptyFailoverEndTime if row.FailoverEndTime != nil { @@ -434,11 +428,7 @@ func (db *cdb) SelectAllDomains( } // Delete a domain, either by domainID or domainName -func (db *cdb) DeleteDomain( - ctx context.Context, - domainID *string, - domainName *string, -) error { +func (db *cdb) DeleteDomain(ctx context.Context, domainID *string, domainName *string) error { if domainName == nil && domainID == nil { return fmt.Errorf("must provide either domainID or domainName") } @@ -455,26 +445,24 @@ func (db *cdb) DeleteDomain( } domainName = common.StringPtr(name) } else { - var ID string + var id string query := db.session.Query(templateGetDomainByNameQueryV2, constDomainPartition, *domainName).WithContext(ctx) - err := query.Scan(&ID, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + err := query.Scan(&id, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) if err != nil { if db.client.IsNotFoundError(err) { return nil } return err } - domainID = common.StringPtr(ID) + domainID = common.StringPtr(id) } return db.deleteDomain(ctx, *domainName, *domainID) } -func (db *cdb) SelectDomainMetadata( - ctx context.Context, -) (int64, error) { +func (db *cdb) SelectDomainMetadata(ctx context.Context) (int64, error) { var notificationVersion int64 - query := db.session.Query(templateGetMetadataQueryV2, constDomainPartition, domainMetadataRecordName) + query := db.session.Query(templateGetMetadataQueryV2, constDomainPartition, domainMetadataRecordName).WithContext(ctx) err := query.Scan(¬ificationVersion) if err != nil { if db.client.IsNotFoundError(err) { @@ -488,10 +476,7 @@ func (db *cdb) SelectDomainMetadata( return notificationVersion, nil } -func (db *cdb) deleteDomain( - ctx context.Context, - name, ID string, -) error { +func (db *cdb) deleteDomain(ctx context.Context, name, ID string) error { query := db.session.Query(templateDeleteDomainByNameQueryV2, constDomainPartition, name).WithContext(ctx) if err := db.executeWithConsistencyAll(query); err != nil { return err diff --git a/common/persistence/nosql/nosqlplugin/cassandra/domain_test.go b/common/persistence/nosql/nosqlplugin/cassandra/domain_test.go new file mode 100644 index 00000000000..32e06e71d3c --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/domain_test.go @@ -0,0 +1,945 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cassandra + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/config" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra/gocql" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra/testdata" + "github.com/uber/cadence/common/types" +) + +func TestInsertDomain(t *testing.T) { + ts, err := time.Parse(time.RFC3339, "2024-04-03T18:00:00Z") + if err != nil { + t.Fatalf("Failed to parse time: %v", err) + } + + tests := []struct { + name string + row *nosqlplugin.DomainRow + queryMockFn func(query *gocql.MockQuery) + clientMockFn func(client *gocql.MockClient) + mapExecuteBatchCASApplied bool + mapExecuteBatchCASPrev map[string]any + mapExecuteBatchCASErr error + wantSessionQueries []string + wantBatchQueries []string + wantErr bool + }{ + { + name: "insertion MapScanCAS failed", + row: testdata.NewDomainRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for insert + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return false, errors.New("some random error") + }).Times(1) + }, + wantErr: true, + }, + { + name: "insertion MapScanCAS could not apply", + row: testdata.NewDomainRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for insert + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return false, nil + }).Times(1) + }, + wantErr: true, + }, + { + name: "insertion success - select metadata failed", + row: testdata.NewDomainRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for insert + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + + // mock calls for SelectDomainMetadata + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some random error") + }).Times(1) + }, + clientMockFn: func(client *gocql.MockClient) { + client.EXPECT().IsNotFoundError(gomock.Any()).Return(false).Times(1) + }, + wantErr: true, + }, + { + name: "insertion success - select metadata success - insertion to domains_by_name_v2 failed", + row: testdata.NewDomainRow(ts), + mapExecuteBatchCASErr: errors.New("insert to domains_by_name_v2 failed"), + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for insert + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + + // mock calls for SelectDomainMetadata + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return nil + }).Times(1) + }, + wantErr: true, + }, + { + name: "insertion success - select metadata success - insertion to domains_by_name_v2 not applied - orphan domain deletion failed", + row: testdata.NewDomainRow(ts), + mapExecuteBatchCASApplied: false, + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for insert + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + + // mock calls for SelectDomainMetadata + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return nil + }).Times(1) + + // mock calls for deleting orphan domain + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(errors.New("orphan domain deletion failure")).Times(1) + }, + wantErr: true, + }, + { + name: "insertion success - select metadata success - insertion to domains_by_name_v2 not applied - domain already exists", + row: testdata.NewDomainRow(ts), + mapExecuteBatchCASApplied: false, + mapExecuteBatchCASPrev: map[string]any{ + "name": testdata.NewDomainRow(ts).Info.Name, // this will causedomain already exist error + }, + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for insert + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + + // mock calls for SelectDomainMetadata + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return nil + }).Times(1) + + // mock calls for deleting orphan domain + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + }, + wantErr: true, + }, + { + name: "all success", + row: testdata.NewDomainRow(ts), + mapExecuteBatchCASApplied: true, + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for insert + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + + // mock calls for SelectDomainMetadata + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + notificationVersion := args[0].(*int64) + *notificationVersion = 7 + return nil + }).Times(1) + }, + wantSessionQueries: []string{ + `INSERT INTO domains (id, domain) VALUES(test-domain-id, {name: test-domain-name}) IF NOT EXISTS`, + `SELECT notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 and name = cadence-domain-metadata `, + }, + wantBatchQueries: []string{ + `INSERT INTO domains_by_name_v2 (` + + `domains_partition, ` + + `name, ` + + `domain, ` + + `config, ` + + `replication_config, ` + + `is_global_domain, ` + + `config_version, ` + + `failover_version, ` + + `failover_notification_version, ` + + `previous_failover_version, ` + + `failover_end_time, ` + + `last_updated_time, ` + + `notification_version) ` + + `VALUES(` + + `0, ` + + `test-domain-name, ` + + `{id: test-domain-id, name: test-domain-name, status: 0, description: test-domain-description, owner_email: test-domain-owner-email, data: map[k1:v1] }, ` + + `{retention: 7, emit_metric: true, archival_bucket: test-archival-bucket, archival_status: ENABLED,history_archival_status: ENABLED, history_archival_uri: test-history-archival-uri, visibility_archival_status: ENABLED, visibility_archival_uri: test-visibility-archival-uri, bad_binaries: [98 97 100 45 98 105 110 97 114 105 101 115],bad_binaries_encoding: thriftrw,isolation_groups: [105 115 111 108 97 116 105 111 110 45 103 114 111 117 112],isolation_groups_encoding: thriftrw,async_workflow_config: [97 115 121 110 99 45 119 111 114 107 102 108 111 119 115 45 99 111 110 102 105 103],async_workflow_config_encoding: thriftrw}, ` + + `{active_cluster_name: test-active-cluster-name, clusters: [map[cluster_name:test-cluster-name]] }, ` + + `true, ` + + `3, ` + + `4, ` + + `0, ` + + `-1, ` + + `1712167200000000000, ` + + `1712167200000000000, ` + + `7) ` + + `IF NOT EXISTS`, + `UPDATE domains_by_name_v2 SET notification_version = 8 WHERE domains_partition = 0 and name = cadence-domain-metadata IF notification_version = 7 `, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + mapExecuteBatchCASApplied: tc.mapExecuteBatchCASApplied, + mapExecuteBatchCASPrev: tc.mapExecuteBatchCASPrev, + mapExecuteBatchCASErr: tc.mapExecuteBatchCASErr, + iter: &fakeIter{}, + } + client := gocql.NewMockClient(ctrl) + if tc.clientMockFn != nil { + tc.clientMockFn(client) + } + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + err := db.InsertDomain(context.Background(), tc.row) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantSessionQueries, session.queries); diff != "" { + t.Fatalf("Session queries mismatch (-want +got):\n%s", diff) + } + + if len(session.batches) != 1 { + t.Fatalf("Expected 1 batch, got %v", len(session.batches)) + } + + if diff := cmp.Diff(tc.wantBatchQueries, session.batches[0].queries); diff != "" { + t.Fatalf("Batch queries mismatch (-want +got):\n%s", diff) + } + + if !session.iter.closed { + t.Error("Expected iter to be closed") + } + }) + } +} + +func TestUpdateDomain(t *testing.T) { + ts, err := time.Parse(time.RFC3339, "2024-04-04T18:00:00Z") + if err != nil { + t.Fatalf("Failed to parse time: %v", err) + } + + tests := []struct { + name string + row *nosqlplugin.DomainRow + mapExecuteBatchCASApplied bool + mapExecuteBatchCASPrev map[string]any + mapExecuteBatchCASErr error + wantBatchQueries []string + wantErr bool + }{ + { + name: "mapExecuteBatchCAS could not apply", + row: func() *nosqlplugin.DomainRow { + r := testdata.NewDomainRow(ts) + r.FailoverEndTime = nil + return r + }(), + mapExecuteBatchCASApplied: false, + wantErr: true, + }, + { + name: "mapExecuteBatchCAS failed", + row: func() *nosqlplugin.DomainRow { + r := testdata.NewDomainRow(ts) + r.FailoverEndTime = nil + return r + }(), + mapExecuteBatchCASErr: errors.New("some random error"), + wantErr: true, + }, + { + name: "empty failover end time", + row: func() *nosqlplugin.DomainRow { + r := testdata.NewDomainRow(ts) + r.FailoverEndTime = nil + return r + }(), + mapExecuteBatchCASApplied: true, + wantBatchQueries: []string{ + `UPDATE domains_by_name_v2 SET ` + + `domain = {id: test-domain-id, name: test-domain-name, status: 0, description: test-domain-description, owner_email: test-domain-owner-email, data: map[k1:v1] }, ` + + `config = {retention: 7, emit_metric: true, archival_bucket: test-archival-bucket, archival_status: ENABLED,history_archival_status: ENABLED, history_archival_uri: test-history-archival-uri, visibility_archival_status: ENABLED, visibility_archival_uri: test-visibility-archival-uri, bad_binaries: [98 97 100 45 98 105 110 97 114 105 101 115],bad_binaries_encoding: thriftrw,isolation_groups: [105 115 111 108 97 116 105 111 110 45 103 114 111 117 112],isolation_groups_encoding: thriftrw,async_workflow_config: [97 115 121 110 99 45 119 111 114 107 102 108 111 119 115 45 99 111 110 102 105 103],async_workflow_config_encoding: thriftrw}, ` + + `replication_config = {active_cluster_name: test-active-cluster-name, clusters: [map[cluster_name:test-cluster-name]] }, ` + + `config_version = 3 ,` + + `failover_version = 4 ,` + + `failover_notification_version = 0 , ` + + `previous_failover_version = 0 , ` + + `failover_end_time = 0,` + + `last_updated_time = 1712253600000000000,` + + `notification_version = 5 ` + + `WHERE domains_partition = 0 and name = test-domain-name`, + `UPDATE domains_by_name_v2 SET notification_version = 6 WHERE ` + + `domains_partition = 0 and ` + + `name = cadence-domain-metadata ` + + `IF notification_version = 5 `, + }, + }, + { + name: "success", + row: testdata.NewDomainRow(ts), + mapExecuteBatchCASApplied: true, + wantBatchQueries: []string{ + `UPDATE domains_by_name_v2 SET ` + + `domain = {id: test-domain-id, name: test-domain-name, status: 0, description: test-domain-description, owner_email: test-domain-owner-email, data: map[k1:v1] }, ` + + `config = {retention: 7, emit_metric: true, archival_bucket: test-archival-bucket, archival_status: ENABLED,history_archival_status: ENABLED, history_archival_uri: test-history-archival-uri, visibility_archival_status: ENABLED, visibility_archival_uri: test-visibility-archival-uri, bad_binaries: [98 97 100 45 98 105 110 97 114 105 101 115],bad_binaries_encoding: thriftrw,isolation_groups: [105 115 111 108 97 116 105 111 110 45 103 114 111 117 112],isolation_groups_encoding: thriftrw,async_workflow_config: [97 115 121 110 99 45 119 111 114 107 102 108 111 119 115 45 99 111 110 102 105 103],async_workflow_config_encoding: thriftrw}, ` + + `replication_config = {active_cluster_name: test-active-cluster-name, clusters: [map[cluster_name:test-cluster-name]] }, ` + + `config_version = 3 ,` + + `failover_version = 4 ,` + + `failover_notification_version = 0 , ` + + `previous_failover_version = 0 , ` + + `failover_end_time = 1712253600000000000,` + + `last_updated_time = 1712253600000000000,` + + `notification_version = 5 ` + + `WHERE domains_partition = 0 and name = test-domain-name`, + `UPDATE domains_by_name_v2 SET notification_version = 6 WHERE ` + + `domains_partition = 0 and ` + + `name = cadence-domain-metadata ` + + `IF notification_version = 5 `, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + session := &fakeSession{ + query: query, + mapExecuteBatchCASApplied: tc.mapExecuteBatchCASApplied, + mapExecuteBatchCASPrev: tc.mapExecuteBatchCASPrev, + mapExecuteBatchCASErr: tc.mapExecuteBatchCASErr, + iter: &fakeIter{}, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + err := db.UpdateDomain(context.Background(), tc.row) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if len(session.batches) != 1 { + t.Fatalf("Expected 1 batch, got %v", len(session.batches)) + } + + if diff := cmp.Diff(tc.wantBatchQueries, session.batches[0].queries); diff != "" { + t.Fatalf("Batch queries mismatch (-want +got):\n%s", diff) + } + + if !session.iter.closed { + t.Error("Expected iter to be closed") + } + }) + } +} + +func TestSelectDomain(t *testing.T) { + tests := []struct { + name string + domainID *string + domainName *string + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "both domainName and domainID not provided", + domainName: nil, + domainID: nil, + wantErr: true, + }, + { + name: "both domainName and domainID provided", + domainID: common.StringPtr("domain_id_1"), + domainName: common.StringPtr("domain_name_1"), + wantErr: true, + }, + { + name: "domainName not provided - success", + domainID: common.StringPtr("domain_id_1"), + domainName: nil, + queryMockFn: func(query *gocql.MockQuery) { + // mock calls to select domainName + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + name := args[0].(**string) + domainName := "domain_name_1" + *name = &domainName + return nil + }).Times(1) + + // mock calls to select domain + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).Return(nil).Times(1) + }, + wantQueries: []string{ + `SELECT domain.name FROM domains WHERE id = domain_id_1`, + `SELECT domain.id, domain.name, domain.status, domain.description, domain.owner_email, domain.data, config.retention, config.emit_metric, config.archival_bucket, config.archival_status, config.history_archival_status, config.history_archival_uri, config.visibility_archival_status, config.visibility_archival_uri, config.bad_binaries, config.bad_binaries_encoding, replication_config.active_cluster_name, replication_config.clusters, config.isolation_groups,config.isolation_groups_encoding,config.async_workflow_config,config.async_workflow_config_encoding,is_global_domain, config_version, failover_version, failover_notification_version, previous_failover_version, failover_end_time, last_updated_time, notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 and name = domain_name_1`, + }, + }, + { + name: "domainName not provided - scan failure", + domainID: common.StringPtr("domain_id_1"), + domainName: nil, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some random error") + }).Times(1) + }, + wantErr: true, + }, + { + name: "domainID not provided - scan failure", + domainID: nil, + domainName: common.StringPtr("domain_name_1"), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some random error") + }).Times(1) + + }, + wantErr: true, + }, + { + name: "domainID not provided - success", + domainID: nil, + domainName: common.StringPtr("domain_name_1"), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).Return(nil).Times(1) + }, + wantQueries: []string{ + `SELECT domain.id, domain.name, domain.status, domain.description, domain.owner_email, domain.data, config.retention, config.emit_metric, config.archival_bucket, config.archival_status, config.history_archival_status, config.history_archival_uri, config.visibility_archival_status, config.visibility_archival_uri, config.bad_binaries, config.bad_binaries_encoding, replication_config.active_cluster_name, replication_config.clusters, config.isolation_groups,config.isolation_groups_encoding,config.async_workflow_config,config.async_workflow_config_encoding,is_global_domain, config_version, failover_version, failover_notification_version, previous_failover_version, failover_end_time, last_updated_time, notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 and name = domain_name_1`, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + if tc.queryMockFn != nil { + tc.queryMockFn(query) + } + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + gotRow, err := db.SelectDomain(context.Background(), tc.domainID, tc.domainName) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if gotRow == nil { + t.Error("Expected domain row to be returned") + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestSelectAllDomains(t *testing.T) { + ts, err := time.Parse(time.RFC3339, "2024-04-03T18:00:00Z") + if err != nil { + t.Fatalf("Failed to parse time: %v", err) + } + + tests := []struct { + name string + pageSize int + pagetoken []byte + iter *fakeIter + wantQueries []string + wantRows []*nosqlplugin.DomainRow + wantErr bool + }{ + { + name: "nil iter", + wantErr: true, + }, + { + name: "iter close failed", + iter: &fakeIter{closeErr: errors.New("some random error")}, + wantErr: true, + }, + { + name: "success", + iter: &fakeIter{ + scanInputs: [][]interface{}{ + { + "domain_name_1", + "domain_id_1", + "domain_name_1", + persistence.DomainStatusRegistered, + "domain_description_1", + "domain_owner_email_1", + map[string]string{"k1": "v1"}, + int32(7), + true, + "test-archival-bucket", + types.ArchivalStatusEnabled, + types.ArchivalStatusEnabled, + "test-history-archival-uri", + types.ArchivalStatusEnabled, + "test-visibility-archival-uri", + []byte("bad-binaries"), + "thriftrw", + []byte("isolation-groups"), + "thriftrw", + []byte("async-workflow-config"), + "thriftrw", + "test-active-cluster-name", + []map[string]interface{}{}, + true, + int64(3), + int64(4), + int64(0), + int64(-1), + int64(1712167200000000000), + int64(1712167200000000000), + int64(7), + }, + }, + }, + wantRows: []*nosqlplugin.DomainRow{ + { + Info: &persistence.DomainInfo{ + ID: "domain_id_1", + Name: "domain_name_1", + Description: "domain_description_1", + OwnerEmail: "domain_owner_email_1", + Data: map[string]string{"k1": "v1"}, + }, + Config: &nosqlplugin.NoSQLInternalDomainConfig{ + Retention: 7 * 24 * time.Hour, + EmitMetric: true, + ArchivalBucket: "test-archival-bucket", + ArchivalStatus: types.ArchivalStatusEnabled, + HistoryArchivalStatus: types.ArchivalStatusEnabled, + HistoryArchivalURI: "test-history-archival-uri", + VisibilityArchivalStatus: types.ArchivalStatusEnabled, + VisibilityArchivalURI: "test-visibility-archival-uri", + BadBinaries: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("bad-binaries")}, + IsolationGroups: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("isolation-groups")}, + AsyncWorkflowsConfig: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("async-workflow-config")}, + }, + ReplicationConfig: &persistence.DomainReplicationConfig{ + ActiveClusterName: "test-active-cluster-name", + Clusters: []*persistence.ClusterReplicationConfig{}, + }, + ConfigVersion: 3, + FailoverVersion: 4, + PreviousFailoverVersion: -1, + FailoverEndTime: &ts, + NotificationVersion: 7, + LastUpdatedTime: ts, + IsGlobalDomain: true, + }, + }, + wantQueries: []string{ + `SELECT name, domain.id, domain.name, domain.status, domain.description, domain.owner_email, domain.data, config.retention, config.emit_metric, config.archival_bucket, config.archival_status, config.history_archival_status, config.history_archival_uri, config.visibility_archival_status, config.visibility_archival_uri, config.bad_binaries, config.bad_binaries_encoding, config.isolation_groups, config.isolation_groups_encoding, config.async_workflow_config, config.async_workflow_config_encoding, replication_config.active_cluster_name, replication_config.clusters, is_global_domain, config_version, failover_version, failover_notification_version, previous_failover_version, failover_end_time, last_updated_time, notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 `, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + query.EXPECT().PageSize(gomock.Any()).Return(query).Times(1) + query.EXPECT().PageState(gomock.Any()).Return(query).Times(1) + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + if tc.iter != nil { + query.EXPECT().Iter().Return(tc.iter).Times(1) + } else { + query.EXPECT().Iter().Return(nil).Times(1) + } + + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + db := newCassandraDBFromSession(cfg, session, logger, nil, dbWithClient(client)) + + gotRows, _, err := db.SelectAllDomains(context.Background(), tc.pageSize, tc.pagetoken) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantRows, gotRows); diff != "" { + t.Fatalf("Task rows mismatch (-want +got):\n%s", diff) + } + + if !tc.iter.closed { + t.Fatal("iterator not closed") + } + }) + } +} + +func TestSelectDomainMetadata(t *testing.T) { + tests := []struct { + name string + queryMockFn func(query *gocql.MockQuery) + clientMockFn func(client *gocql.MockClient) + wantNtfVer int64 + wantQueries []string + wantErr bool + }{ + { + name: "success", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + notificationVersion := args[0].(*int64) + *notificationVersion = 3 + return nil + }).Times(1) + }, + wantNtfVer: 3, + wantQueries: []string{ + `SELECT notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 and name = cadence-domain-metadata `, + }, + }, + { + name: "scan failure - isnotfound", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some error that will be considered as not found by client mock") + }).Times(1) + }, + clientMockFn: func(client *gocql.MockClient) { + client.EXPECT().IsNotFoundError(gomock.Any()).Return(true).Times(1) + }, + wantNtfVer: 0, + wantQueries: []string{ + `SELECT notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 and name = cadence-domain-metadata `, + }, + }, + { + name: "scan failure", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some random error") + }).Times(1) + }, + clientMockFn: func(client *gocql.MockClient) { + client.EXPECT().IsNotFoundError(gomock.Any()).Return(false).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + if tc.clientMockFn != nil { + tc.clientMockFn(client) + } + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + gotNtfVer, err := db.SelectDomainMetadata(context.Background()) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if gotNtfVer != tc.wantNtfVer { + t.Errorf("Got notification version = %v, want %v", gotNtfVer, tc.wantNtfVer) + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestDeleteDomain(t *testing.T) { + tests := []struct { + name string + domainID *string + domainName *string + queryMockFn func(query *gocql.MockQuery) + clientMockFn func(client *gocql.MockClient) + wantQueries []string + wantErr bool + }{ + { + name: "both domainName and domainID not provided", + domainName: nil, + domainID: nil, + wantErr: true, + }, + { + name: "domainName not provided", + domainID: common.StringPtr("domain_id_1"), + domainName: nil, + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for initial select + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + name := args[0].(*string) + *name = "domain_name_1" + return nil + }).Times(1) + + // mock calls for delete from domains_by_name_v2 + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + + // mock calls for delete from domains + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + }, + wantQueries: []string{ + `SELECT domain.name FROM domains WHERE id = domain_id_1`, + `DELETE FROM domains_by_name_v2 WHERE domains_partition = 0 and name = domain_name_1`, + `DELETE FROM domains WHERE id = domain_id_1`, + }, + }, + { + name: "domainName not provided - scan failure - isnotfound", + domainID: common.StringPtr("domain_id_1"), + domainName: nil, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some error that will be considered as not found by client mock") + }).Times(1) + }, + clientMockFn: func(client *gocql.MockClient) { + client.EXPECT().IsNotFoundError(gomock.Any()).Return(true).Times(1) + }, + wantQueries: []string{ + `SELECT domain.name FROM domains WHERE id = domain_id_1`, + }, + }, + { + name: "domainName not provided - scan failure", + domainID: common.StringPtr("domain_id_1"), + domainName: nil, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some random error") + }).Times(1) + }, + clientMockFn: func(client *gocql.MockClient) { + client.EXPECT().IsNotFoundError(gomock.Any()).Return(false).Times(1) + }, + wantErr: true, + }, + { + name: "domainID not provided", + domainID: nil, + domainName: common.StringPtr("domain_name_1"), + queryMockFn: func(query *gocql.MockQuery) { + // mock calls for initial select + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + // Ideally we should be using DoAndReturn here to set the domainID, + // but it panics because gomock doesn't handle n-arity func calls with nil params such as query.Scan(&id, nil, nil) + query.EXPECT().Scan(gomock.Any()).Return(nil).Times(1) + + // mock calls for delete from domains_by_name_v2 + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + + // mock calls for delete from domains + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + }, + wantQueries: []string{ + `SELECT domain.id, domain.name, domain.status, domain.description, domain.owner_email, domain.data, config.retention, config.emit_metric, config.archival_bucket, config.archival_status, config.history_archival_status, config.history_archival_uri, config.visibility_archival_status, config.visibility_archival_uri, config.bad_binaries, config.bad_binaries_encoding, replication_config.active_cluster_name, replication_config.clusters, config.isolation_groups,config.isolation_groups_encoding,config.async_workflow_config,config.async_workflow_config_encoding,is_global_domain, config_version, failover_version, failover_notification_version, previous_failover_version, failover_end_time, last_updated_time, notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 and name = domain_name_1`, + `DELETE FROM domains_by_name_v2 WHERE domains_partition = 0 and name = domain_name_1`, + `DELETE FROM domains WHERE id = `, // domainID is nil, so we expect an empty string here. See the comment above inside mockQueryFn. + }, + }, + { + name: "domainID not provided - scan failure - isnotfound", + domainID: nil, + domainName: common.StringPtr("domain_name_1"), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).Return(errors.New("some error that will be considered as not found by client mock")).Times(1) + }, + clientMockFn: func(client *gocql.MockClient) { + client.EXPECT().IsNotFoundError(gomock.Any()).Return(true).Times(1) + }, + wantQueries: []string{ + `SELECT domain.id, domain.name, domain.status, domain.description, domain.owner_email, domain.data, config.retention, config.emit_metric, config.archival_bucket, config.archival_status, config.history_archival_status, config.history_archival_uri, config.visibility_archival_status, config.visibility_archival_uri, config.bad_binaries, config.bad_binaries_encoding, replication_config.active_cluster_name, replication_config.clusters, config.isolation_groups,config.isolation_groups_encoding,config.async_workflow_config,config.async_workflow_config_encoding,is_global_domain, config_version, failover_version, failover_notification_version, previous_failover_version, failover_end_time, last_updated_time, notification_version FROM domains_by_name_v2 WHERE domains_partition = 0 and name = domain_name_1`, + }, + }, + { + name: "domainID not provided - scan failure", + domainID: nil, + domainName: common.StringPtr("domain_name_1"), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).Return(errors.New("some random error")).Times(1) + }, + clientMockFn: func(client *gocql.MockClient) { + client.EXPECT().IsNotFoundError(gomock.Any()).Return(false).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + if tc.queryMockFn != nil { + tc.queryMockFn(query) + } + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + if tc.clientMockFn != nil { + tc.clientMockFn(client) + } + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{ + EnableCassandraAllConsistencyLevelDelete: func(opts ...dynamicconfig.FilterOption) bool { + return false + }, + } + + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + err := db.DeleteDomain(context.Background(), tc.domainID, tc.domainName) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/client_test.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/client_test.go new file mode 100644 index 00000000000..445884dcba2 --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/client_test.go @@ -0,0 +1,89 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package gocql + +import ( + "testing" + + "github.com/gocql/gocql" + gomock "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common/config" +) + +func Test_GetRegisteredClient(t *testing.T) { + assert.Panics(t, func() { GetRegisteredClient() }) +} + +func Test_GetRegisteredClientNotNil(t *testing.T) { + mockCtrl := gomock.NewController(t) + registered = NewMockClient(mockCtrl) + assert.Equal(t, registered, GetRegisteredClient()) +} + +func Test_RegisterClient(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + RegisterClient(nil) +} + +func Test_RegisterClientNotNil(t *testing.T) { + mockCtrl := gomock.NewController(t) + newClient := NewMockClient(mockCtrl) + registered = nil + RegisterClient(newClient) + assert.Equal(t, newClient, registered) +} + +func Test_newCassandraCluster(t *testing.T) { + testFullConfig := ClusterConfig{ + Hosts: "testHost1,testHost2,testHost3,testHost4", + Port: 123, + User: "testUser", + Password: "testPassword", + Keyspace: "testKeyspace", + Datacenter: "testDatacenter", + Region: "testRegion", + TLS: &config.TLS{ + Enabled: true, + CertFile: "testCertFile", + KeyFile: "testKeyFile", + }, + MaxConns: 10, + } + clusterConfig := newCassandraCluster(testFullConfig) + assert.Equal(t, []string{"testHost1", "testHost2", "testHost3", "testHost4"}, clusterConfig.Hosts) + assert.Equal(t, testFullConfig.Port, clusterConfig.Port) + assert.Equal(t, testFullConfig.User, clusterConfig.Authenticator.(gocql.PasswordAuthenticator).Username) + assert.Equal(t, testFullConfig.Password, clusterConfig.Authenticator.(gocql.PasswordAuthenticator).Password) + assert.Equal(t, testFullConfig.Keyspace, clusterConfig.Keyspace) + assert.Equal(t, testFullConfig.TLS.CertFile, clusterConfig.SslOpts.CertPath) + assert.Equal(t, testFullConfig.TLS.KeyFile, clusterConfig.SslOpts.KeyPath) + assert.Equal(t, testFullConfig.MaxConns, clusterConfig.NumConns) + + assert.False(t, clusterConfig.HostFilter.Accept(&gocql.HostInfo{})) +} diff --git a/common/persistence/pinotVisibilityTripleManager_test.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency_test.go similarity index 52% rename from common/persistence/pinotVisibilityTripleManager_test.go rename to common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency_test.go index 016a04ec405..81d21fb3961 100644 --- a/common/persistence/pinotVisibilityTripleManager_test.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/consistency_test.go @@ -20,43 +20,48 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -package persistence +package gocql import ( "testing" + "github.com/gocql/gocql" "github.com/stretchr/testify/assert" ) -func TestFilterAttrPrefix(t *testing.T) { - tests := map[string]struct { - expectedInput string - expectedOutput string +func Test_mustConvertConsistency(t *testing.T) { + tests := []struct { + input Consistency + output gocql.Consistency }{ - "Case1: empty input": { - expectedInput: "", - expectedOutput: "", - }, - "Case2: filtered input": { - expectedInput: "`Attr.CustomIntField` = 12", - expectedOutput: "CustomIntField = 12", - }, - "Case3: complex input": { - expectedInput: "WorkflowID = 'test-wf' and (`Attr.CustomIntField` = 12 or `Attr.CustomStringField` = 'a-b-c' and WorkflowType = 'wf-type')", - expectedOutput: "WorkflowID = 'test-wf' and (CustomIntField = 12 or CustomStringField = 'a-b-c' and WorkflowType = 'wf-type')", - }, - "Case4: false positive case": { - expectedInput: "`Attr.CustomStringField` = '`Attr.ABCtesting'", - expectedOutput: "CustomStringField = 'ABCtesting'", // this is supposed to be CustomStringField = '`Attr.ABCtesting' - }, + {Any, gocql.Any}, + {One, gocql.One}, + {Two, gocql.Two}, + {Three, gocql.Three}, + {Quorum, gocql.Quorum}, + {All, gocql.All}, + {LocalQuorum, gocql.LocalQuorum}, + {EachQuorum, gocql.EachQuorum}, + {LocalOne, gocql.LocalOne}, } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - assert.NotPanics(t, func() { - actualOutput := filterAttrPrefix(test.expectedInput) - assert.Equal(t, test.expectedOutput, actualOutput) - }) - }) + for _, tt := range tests { + assert.Equal(t, tt.output, mustConvertConsistency(tt.input)) } + assert.Panics(t, func() { mustConvertConsistency(Consistency(9999)) }) +} + +func Test_mustConvertSerialConsistency(t *testing.T) { + tests := []struct { + input SerialConsistency + output gocql.SerialConsistency + }{ + {Serial, gocql.Serial}, + {LocalSerial, gocql.LocalSerial}, + } + + for _, tt := range tests { + assert.Equal(t, tt.output, mustConvertSerialConsistency(tt.input)) + } + assert.Panics(t, func() { mustConvertSerialConsistency(SerialConsistency(9999)) }) } diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/public/client_test.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/public/client_test.go new file mode 100644 index 00000000000..d904a3ff875 --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/public/client_test.go @@ -0,0 +1,190 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package public + +import ( + "context" + "fmt" + "testing" + + "github.com/gocql/gocql" + "github.com/stretchr/testify/assert" +) + +// MockError to simulate gocql.Error behavior +type MockError struct { + gocql.RequestError + code int + message string +} + +func (m MockError) Code() int { + return m.code +} + +func (m MockError) Message() string { + return m.message +} + +func TestClient_IsTimeoutError(t *testing.T) { + client := client{} + errorMap := map[error]bool{ + nil: false, + context.DeadlineExceeded: true, + gocql.ErrTimeoutNoResponse: true, + gocql.ErrConnectionClosed: true, + &gocql.RequestErrWriteTimeout{}: true, + gocql.ErrFrameTooBig: false, + } + for err, expected := range errorMap { + assert.Equal(t, expected, client.IsTimeoutError(err)) + } +} + +func TestClient_IsNotFoundError(t *testing.T) { + client := client{} + errorMap := map[error]bool{ + nil: false, + gocql.ErrNotFound: true, + gocql.ErrFrameTooBig: false, + } + for err, expected := range errorMap { + assert.Equal(t, expected, client.IsNotFoundError(err)) + } +} + +// TestClient_IsThrottlingError tests the IsThrottlingError function with different error codes +func TestClient_IsThrottlingError(t *testing.T) { + client := client{} + tests := []struct { + name string + mockErrorCode int + expectedResult bool + nonCompatibleError error + }{ + { + name: "With Throttling Error", + mockErrorCode: 0x1001, + expectedResult: true, + }, + { + name: "With Non-Throttling Error", + mockErrorCode: 0x0001, + expectedResult: false, + nonCompatibleError: fmt.Errorf("with Non-Throttling Error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.nonCompatibleError != nil { + result := client.IsThrottlingError(tt.nonCompatibleError) + assert.False(t, result) + } + err := MockError{code: tt.mockErrorCode} + result := client.IsThrottlingError(err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestClient_IsDBUnavailableError(t *testing.T) { + client := client{} + tests := []struct { + name string + mockMessage string + mockErrorCode int + expectedResult bool + nonCompatibleError error + }{ + { + name: "With DB Unavailable Error", + mockMessage: "Cannot perform LWT operation", + mockErrorCode: 0x1000, + expectedResult: true, + }, + { + name: "With Non-DB Unavailable Error", + mockMessage: "Cannot perform LWT operation", + mockErrorCode: 0x0001, + expectedResult: false, + }, + { + name: "With Non-compatible Error", + mockErrorCode: 0x0001, + expectedResult: false, + nonCompatibleError: fmt.Errorf("with Non-compatible Error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.nonCompatibleError != nil { + result := client.IsDBUnavailableError(tt.nonCompatibleError) + assert.False(t, result) + } + err := MockError{code: tt.mockErrorCode, message: tt.mockMessage} + result := client.IsDBUnavailableError(err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestClient_IsCassandraConsistencyError(t *testing.T) { + client := client{} + tests := []struct { + name string + mockErrorCode int + expectedResult bool + nonCompatibleError error + }{ + { + name: "With Cassandra Consistency Error", + mockErrorCode: 0x1000, + expectedResult: true, + }, + { + name: "With Non-Cassandra Consistency Error", + mockErrorCode: 0x0001, + expectedResult: false, + }, + { + name: "With Non-compatible Error", + mockErrorCode: 0x0001, + expectedResult: false, + nonCompatibleError: fmt.Errorf("with Non-compatible Error"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.nonCompatibleError != nil { + result := client.IsCassandraConsistencyError(tt.nonCompatibleError) + assert.False(t, result) + } + err := MockError{code: tt.mockErrorCode} + result := client.IsCassandraConsistencyError(err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/public/testBase.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/public/testdata.go similarity index 100% rename from common/persistence/nosql/nosqlplugin/cassandra/gocql/public/testBase.go rename to common/persistence/nosql/nosqlplugin/cassandra/gocql/public/testdata.go diff --git a/common/persistence/nosql/nosqlplugin/cassandra/queue.go b/common/persistence/nosql/nosqlplugin/cassandra/queue.go index ade5b3fded4..3530d91c356 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/queue.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/queue.go @@ -252,16 +252,10 @@ func (db *cdb) GetQueueSize( return result["count"].(int64), nil } -func getMessagePayload( - message map[string]interface{}, -) []byte { - +func getMessagePayload(message map[string]interface{}) []byte { return message["message_payload"].([]byte) } -func getMessageID( - message map[string]interface{}, -) int64 { - +func getMessageID(message map[string]interface{}) int64 { return message["message_id"].(int64) } diff --git a/common/persistence/nosql/nosqlplugin/cassandra/queue_test.go b/common/persistence/nosql/nosqlplugin/cassandra/queue_test.go new file mode 100644 index 00000000000..7e19d702f28 --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/queue_test.go @@ -0,0 +1,936 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cassandra + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" + + "github.com/uber/cadence/common/config" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra/gocql" +) + +func TestInsertIntoQueue(t *testing.T) { + tests := []struct { + name string + row *nosqlplugin.QueueMessageRow + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "successfully applied", + row: queueMessageRow(101), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + }, + wantQueries: []string{ + `INSERT INTO queue (queue_type, message_id, message_payload) VALUES(1, 101, [116 101 115 116 45 112 97 121 108 111 97 100 45 49 48 49]) IF NOT EXISTS`, + }, + }, + { + name: "not applied", + row: queueMessageRow(101), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return false, nil + }).Times(1) + }, + wantErr: true, + }, + { + name: "mapscancas failed", + row: queueMessageRow(101), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return false, errors.New("mapscancas failed") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + err := db.InsertIntoQueue(context.Background(), tc.row) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestSelectLastEnqueuedMessageID(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantMsgID int64 + wantErr bool + }{ + { + name: "success with shard map fully populated", + queueType: persistence.DomainReplicationQueueType, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScan(gomock.Any()).DoAndReturn(func(m map[string]interface{}) error { + m["message_id"] = int64(101) + return nil + }).Times(1) + }, + wantMsgID: int64(101), + wantQueries: []string{ + `SELECT message_id FROM queue WHERE queue_type=1 ORDER BY message_id DESC LIMIT 1`, + }, + }, + { + name: "mapscan failed", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScan(gomock.Any()).DoAndReturn(func(m map[string]interface{}) error { + return errors.New("mapscan failed") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + gotMsgID, err := db.SelectLastEnqueuedMessageID(context.Background(), tc.queueType) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if gotMsgID != tc.wantMsgID { + t.Errorf("Got message ID = %v, want %v", gotMsgID, tc.wantMsgID) + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestSelectQueueMetadata(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + queryMockFn func(query *gocql.MockQuery) + wantRow *nosqlplugin.QueueMetadataRow + wantQueries []string + wantErr bool + }{ + { + name: "success", + queueType: persistence.QueueType(2), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + ackLevels := args[0].(*map[string]int64) + *ackLevels = make(map[string]int64) + (*ackLevels)["cluster1"] = 1000 + (*ackLevels)["cluster2"] = 2000 + version := args[1].(*int64) + *version = int64(25) + return nil + }).Times(1) + }, + wantRow: &nosqlplugin.QueueMetadataRow{ + QueueType: persistence.QueueType(2), + ClusterAckLevels: map[string]int64{"cluster1": 1000, "cluster2": 2000}, + Version: 25, + }, + wantQueries: []string{ + `SELECT cluster_ack_level, version FROM queue_metadata WHERE queue_type = 2`, + }, + }, + { + name: "success with empty acklevels", + queueType: persistence.QueueType(2), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + version := args[1].(*int64) + *version = int64(26) + return nil + }).Times(1) + }, + wantRow: &nosqlplugin.QueueMetadataRow{ + QueueType: persistence.QueueType(2), + ClusterAckLevels: map[string]int64{}, + Version: 26, + }, + wantQueries: []string{ + `SELECT cluster_ack_level, version FROM queue_metadata WHERE queue_type = 2`, + }, + }, + { + name: "scan failure", + queueType: persistence.QueueType(2), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Scan(gomock.Any()).DoAndReturn(func(args ...interface{}) error { + return errors.New("some random error") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + gotRow, err := db.SelectQueueMetadata(context.Background(), tc.queueType) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantRow, gotRow); diff != "" { + t.Fatalf("Row mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestGetQueueSize(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantCount int64 + wantErr bool + }{ + { + name: "success with shard map fully populated", + queueType: persistence.DomainReplicationQueueType, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScan(gomock.Any()).DoAndReturn(func(m map[string]interface{}) error { + m["count"] = int64(12) + return nil + }).Times(1) + }, + wantCount: int64(12), + wantQueries: []string{ + `SELECT COUNT(1) AS count FROM queue WHERE queue_type=1`, + }, + }, + { + name: "mapscan failed", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScan(gomock.Any()).DoAndReturn(func(m map[string]interface{}) error { + return errors.New("mapscan failed") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + gotCount, err := db.GetQueueSize(context.Background(), tc.queueType) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if gotCount != tc.wantCount { + t.Errorf("Got message ID = %v, want %v", gotCount, tc.wantCount) + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestSelectMessagesFrom(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + exclusiveBeginMessageID int64 + maxRows int + iter *fakeIter + wantQueries []string + wantRows []*nosqlplugin.QueueMessageRow + wantErr bool + }{ + { + name: "nil iter", + queueType: persistence.DomainReplicationQueueType, + exclusiveBeginMessageID: 20, + maxRows: 10, + iter: nil, + wantErr: true, + }, + { + name: "iter close failed", + queueType: persistence.DomainReplicationQueueType, + exclusiveBeginMessageID: 20, + maxRows: 10, + iter: &fakeIter{closeErr: errors.New("some random error")}, + wantErr: true, + }, + { + name: "success", + queueType: persistence.DomainReplicationQueueType, + exclusiveBeginMessageID: 20, + maxRows: 10, + iter: &fakeIter{ + mapScanInputs: []map[string]interface{}{ + { + "message_id": int64(21), + "message_payload": []byte("test-payload-1"), + }, + { + "message_id": int64(22), + "message_payload": []byte("test-payload-2"), + }, + }, + }, + wantRows: []*nosqlplugin.QueueMessageRow{ + { + ID: 21, + Payload: []byte("test-payload-1"), + }, + { + ID: 22, + Payload: []byte("test-payload-2"), + }, + }, + wantQueries: []string{ + `SELECT message_id, message_payload FROM queue WHERE queue_type = 1 and message_id > 20 LIMIT 10`, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + if tc.iter != nil { + query.EXPECT().Iter().Return(tc.iter).Times(1) + } else { + query.EXPECT().Iter().Return(nil).Times(1) + } + + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + db := newCassandraDBFromSession(cfg, session, logger, nil, dbWithClient(client)) + + gotRows, err := db.SelectMessagesFrom(context.Background(), tc.queueType, tc.exclusiveBeginMessageID, tc.maxRows) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantRows, gotRows); diff != "" { + t.Fatalf("rows mismatch (-want +got):\n%s", diff) + } + + if !tc.iter.closed { + t.Fatal("iterator not closed") + } + }) + } +} + +func TestSelectMessagesBetween(t *testing.T) { + tests := []struct { + name string + request nosqlplugin.SelectMessagesBetweenRequest + iter *fakeIter + wantQueries []string + wantResp *nosqlplugin.SelectMessagesBetweenResponse + wantErr bool + }{ + { + name: "nil iter", + request: nosqlplugin.SelectMessagesBetweenRequest{ + QueueType: persistence.DomainReplicationQueueType, + ExclusiveBeginMessageID: 50, + InclusiveEndMessageID: 60, + PageSize: 5, + NextPageToken: []byte("next page token"), + }, + iter: nil, + wantErr: true, + }, + { + name: "iter close failed", + request: nosqlplugin.SelectMessagesBetweenRequest{ + QueueType: persistence.DomainReplicationQueueType, + ExclusiveBeginMessageID: 50, + InclusiveEndMessageID: 60, + PageSize: 5, + NextPageToken: []byte("next page token"), + }, + iter: &fakeIter{closeErr: errors.New("some random error")}, + wantErr: true, + }, + { + name: "success", + request: nosqlplugin.SelectMessagesBetweenRequest{ + QueueType: persistence.DomainReplicationQueueType, + ExclusiveBeginMessageID: 50, + InclusiveEndMessageID: 60, + PageSize: 5, + NextPageToken: []byte("next page token"), + }, + iter: &fakeIter{ + mapScanInputs: []map[string]interface{}{ + { + "message_id": int64(51), + "message_payload": []byte("test-payload-1"), + }, + { + "message_id": int64(52), + "message_payload": []byte("test-payload-2"), + }, + { + "message_id": int64(53), + "message_payload": []byte("test-payload-3"), + }, + }, + pageState: []byte("more pages"), + }, + wantResp: &nosqlplugin.SelectMessagesBetweenResponse{ + Rows: []nosqlplugin.QueueMessageRow{ + { + ID: 51, + Payload: []byte("test-payload-1"), + }, + { + ID: 52, + Payload: []byte("test-payload-2"), + }, + { + ID: 53, + Payload: []byte("test-payload-3"), + }, + }, + NextPageToken: []byte("more pages"), + }, + wantQueries: []string{ + `SELECT message_id, message_payload FROM queue WHERE queue_type = 1 and message_id > 50 and message_id <= 60`, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + query.EXPECT().PageSize(tc.request.PageSize).Return(query).Times(1) + query.EXPECT().PageState(tc.request.NextPageToken).Return(query).Times(1) + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + if tc.iter != nil { + query.EXPECT().Iter().Return(tc.iter).Times(1) + } else { + query.EXPECT().Iter().Return(nil).Times(1) + } + + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + db := newCassandraDBFromSession(cfg, session, logger, nil, dbWithClient(client)) + + gotResp, err := db.SelectMessagesBetween(context.Background(), tc.request) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantResp, gotResp); diff != "" { + t.Fatalf("Response mismatch (-want +got):\n%s", diff) + } + + if !tc.iter.closed { + t.Fatal("iterator not closed") + } + }) + } +} + +func TestDeleteMessagesBefore(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + exclusiveBeginMessageID int64 + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "success", + queueType: persistence.DomainReplicationQueueType, + exclusiveBeginMessageID: 100, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + }, + wantQueries: []string{ + `DELETE FROM queue WHERE queue_type = 1 and message_id < 100`, + }, + }, + { + name: "failure", + queueType: persistence.DomainReplicationQueueType, + exclusiveBeginMessageID: 100, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(errors.New("some random error")).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + db := newCassandraDBFromSession(cfg, session, logger, nil, dbWithClient(client)) + + err := db.DeleteMessagesBefore(context.Background(), tc.queueType, tc.exclusiveBeginMessageID) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestDeleteMessagesInRange(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + exclusiveBeginMessageID int64 + inclusiveEndMsgID int64 + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "success", + queueType: persistence.DomainReplicationQueueType, + exclusiveBeginMessageID: 100, + inclusiveEndMsgID: 200, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + }, + wantQueries: []string{ + `DELETE FROM queue WHERE queue_type = 1 and message_id > 100 and message_id <= 200`, + }, + }, + { + name: "failure", + queueType: persistence.DomainReplicationQueueType, + exclusiveBeginMessageID: 100, + inclusiveEndMsgID: 200, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(errors.New("some random error")).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + db := newCassandraDBFromSession(cfg, session, logger, nil, dbWithClient(client)) + + err := db.DeleteMessagesInRange(context.Background(), tc.queueType, tc.exclusiveBeginMessageID, tc.inclusiveEndMsgID) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestDeleteMessage(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + msgID int64 + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "success", + queueType: persistence.DomainReplicationQueueType, + msgID: 36, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(nil).Times(1) + }, + wantQueries: []string{ + `DELETE FROM queue WHERE queue_type = 1 and message_id = 36`, + }, + }, + { + name: "failure", + queueType: persistence.DomainReplicationQueueType, + msgID: 36, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().Exec().Return(errors.New("some random error")).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + db := newCassandraDBFromSession(cfg, session, logger, nil, dbWithClient(client)) + + err := db.DeleteMessage(context.Background(), tc.queueType, tc.msgID) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestInsertQueueMetadata(t *testing.T) { + tests := []struct { + name string + queueType persistence.QueueType + version int64 + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "success", + queueType: persistence.QueueType(2), + version: 25, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().ScanCAS(gomock.Any()).Return(false, nil).Times(1) + }, + wantQueries: []string{ + `INSERT INTO queue_metadata (queue_type, cluster_ack_level, version) VALUES(2, map[], 25) IF NOT EXISTS`, + }, + }, + { + name: "scan failure", + queueType: persistence.QueueType(2), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().ScanCAS(gomock.Any()).Return(false, errors.New("some random error")).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + err := db.InsertQueueMetadata(context.Background(), tc.queueType, tc.version) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestUpdateQueueMetadataCas(t *testing.T) { + tests := []struct { + name string + row nosqlplugin.QueueMetadataRow + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "successfully applied", + row: nosqlplugin.QueueMetadataRow{ + QueueType: persistence.QueueType(2), + ClusterAckLevels: map[string]int64{"cluster1": 1000, "cluster2": 2000}, + Version: 25, + }, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().ScanCAS(gomock.Any()).Return(true, nil).Times(1) + }, + wantQueries: []string{ + `UPDATE queue_metadata SET cluster_ack_level = map[cluster1:1000 cluster2:2000], version = 25 WHERE queue_type = 2 IF version = 24`, + }, + }, + { + name: "could not apply", + row: nosqlplugin.QueueMetadataRow{ + QueueType: persistence.QueueType(2), + ClusterAckLevels: map[string]int64{"cluster1": 1000, "cluster2": 2000}, + Version: 25, + }, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().ScanCAS(gomock.Any()).Return(false, nil).Times(1) + }, + wantErr: true, + }, + { + name: "scancas failed", + row: nosqlplugin.QueueMetadataRow{ + QueueType: persistence.QueueType(2), + ClusterAckLevels: map[string]int64{"cluster1": 1000, "cluster2": 2000}, + Version: 25, + }, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().ScanCAS(gomock.Any()).Return(false, errors.New("some random error")).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client)) + + err := db.UpdateQueueMetadataCas(context.Background(), tc.row) + + if (err != nil) != tc.wantErr { + t.Errorf("Got error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} +func queueMessageRow(id int64) *nosqlplugin.QueueMessageRow { + return &nosqlplugin.QueueMessageRow{ + QueueType: persistence.DomainReplicationQueueType, + ID: id, + Payload: []byte(fmt.Sprintf("test-payload-%d", id)), + } +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/shard.go b/common/persistence/nosql/nosqlplugin/cassandra/shard.go index cedb51d94b6..de88337b349 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/shard.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/shard.go @@ -34,7 +34,7 @@ import ( // InsertShard creates a new shard, return error is there is any. // Return ShardOperationConditionFailure if the condition doesn't meet func (db *cdb) InsertShard(ctx context.Context, row *nosqlplugin.ShardRow) error { - cqlNowTimestamp := persistence.UnixNanoToDBTimestamp(time.Now().UnixNano()) + cqlNowTimestamp := persistence.UnixNanoToDBTimestamp(db.timeSrc.Now().UnixNano()) markerData, markerEncoding := persistence.FromDataBlob(row.PendingFailoverMarkers) transferPQS, transferPQSEncoding := persistence.FromDataBlob(row.TransferProcessingQueueStates) crossClusterPQS, crossClusterPQSEncoding := persistence.FromDataBlob(row.CrossClusterProcessingQueueStates) @@ -247,7 +247,7 @@ func (db *cdb) UpdateRangeID(ctx context.Context, shardID int, rangeID int64, pr // UpdateShard updates a shard, return error is there is any. // Return ShardOperationConditionFailure if the condition doesn't meet func (db *cdb) UpdateShard(ctx context.Context, row *nosqlplugin.ShardRow, previousRangeID int64) error { - cqlNowTimestamp := persistence.UnixNanoToDBTimestamp(time.Now().UnixNano()) + cqlNowTimestamp := persistence.UnixNanoToDBTimestamp(db.timeSrc.Now().UnixNano()) markerData, markerEncoding := persistence.FromDataBlob(row.PendingFailoverMarkers) transferPQS, transferPQSEncoding := persistence.FromDataBlob(row.TransferProcessingQueueStates) crossClusterPQS, crossClusterPQSEncoding := persistence.FromDataBlob(row.CrossClusterProcessingQueueStates) diff --git a/common/persistence/nosql/nosqlplugin/cassandra/shard_cql.go b/common/persistence/nosql/nosqlplugin/cassandra/shard_cql.go index 1084029263c..cfdccc58c67 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/shard_cql.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/shard_cql.go @@ -47,7 +47,7 @@ const ( `}` templateCreateShardQuery = `INSERT INTO executions (` + - `shard_id, type, domain_id, workflow_id, run_id, visibility_ts, task_id, shard, range_id)` + + `shard_id, type, domain_id, workflow_id, run_id, visibility_ts, task_id, shard, range_id) ` + `VALUES(?, ?, ?, ?, ?, ?, ?, ` + templateShardType + `, ?) IF NOT EXISTS` templateGetShardQuery = `SELECT shard, range_id ` + diff --git a/common/persistence/nosql/nosqlplugin/cassandra/shard_test.go b/common/persistence/nosql/nosqlplugin/cassandra/shard_test.go new file mode 100644 index 00000000000..f29ac8522c0 --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/shard_test.go @@ -0,0 +1,506 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package cassandra + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" + + "github.com/uber/cadence/common/clock" + "github.com/uber/cadence/common/config" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra/gocql" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin/cassandra/testdata" +) + +func TestInsertShard(t *testing.T) { + ts, err := time.Parse(time.RFC3339, "2024-04-02T18:00:00Z") + if err != nil { + t.Fatalf("Failed to parse time: %v", err) + } + + tests := []struct { + name string + row *nosqlplugin.ShardRow + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "successfully applied", + row: testdata.NewShardRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + }, + wantQueries: []string{ + `INSERT INTO executions (` + + `shard_id, type, domain_id, workflow_id, run_id, ` + + `visibility_ts, task_id, ` + + `shard, ` + + `range_id` + + `) ` + + `VALUES(` + + `15, 0, 10000000-1000-f000-f000-000000000000, 20000000-1000-f000-f000-000000000000, 30000000-1000-f000-f000-000000000000, ` + + `946684800000, -11, ` + + `{shard_id: 15, owner: owner, range_id: 1000, stolen_since_renew: 0, updated_at: 1712080800000, replication_ack_level: 2000, transfer_ack_level: 3000, timer_ack_level: 2024-04-02T17:00:00Z, cluster_transfer_ack_level: map[cluster2:4000], cluster_timer_ack_level: map[cluster2:2024-04-02 16:00:00 +0000 UTC], transfer_processing_queue_states: [116 114 97 110 115 102 101 114 113 117 101 117 101], transfer_processing_queue_states_encoding: thriftrw, cross_cluster_processing_queue_states: [120 99 108 117 115 116 101 114 113 117 101 117 101], cross_cluster_processing_queue_states_encoding: thriftrw, timer_processing_queue_states: [116 105 109 101 114 113 117 101 117 101], timer_processing_queue_states_encoding: thriftrw, domain_notification_version: 3, cluster_replication_level: map[cluster2:5000], replication_dlq_ack_level: map[cluster2:10], pending_failover_markers: [102 97 105 108 111 118 101 114 109 97 114 107 101 114 115], pending_failover_markers_encoding: thriftrw }, ` + + `1000` + + `) IF NOT EXISTS`, + }, + }, + { + name: "not applied", + row: testdata.NewShardRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + m["range_id"] = int64(1001) + return false, nil + }).Times(1) + }, + wantErr: true, + }, + { + name: "mapscancas failed", + row: testdata.NewShardRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return false, errors.New("mapscancas failed") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + timeSrc := clock.NewMockedTimeSourceAt(ts) + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client), dbWithTimeSource(timeSrc)) + + err := db.InsertShard(context.Background(), tc.row) + + if (err != nil) != tc.wantErr { + t.Errorf("InsertShard() error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestSelectShard(t *testing.T) { + ts, err := time.Parse(time.RFC3339, "2024-04-02T18:00:00Z") + if err != nil { + t.Fatalf("Failed to parse time: %v", err) + } + + tests := []struct { + name string + shardID int + cluster string + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantRangeID int64 + wantShardInfo *persistence.InternalShardInfo + wantErr bool + }{ + { + name: "success with shard map fully populated", + shardID: 15, + cluster: "cluster1", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScan(gomock.Any()).DoAndReturn(func(m map[string]interface{}) error { + m["range_id"] = int64(1000) + m["shard"] = testdata.NewShardMap(ts) + return nil + }).Times(1) + }, + wantRangeID: int64(1000), + wantShardInfo: &persistence.InternalShardInfo{ + ShardID: 15, + Owner: "owner", + RangeID: 1000, + UpdatedAt: ts, + ReplicationAckLevel: 2000, + ReplicationDLQAckLevel: map[string]int64{"cluster2": 5}, + TransferAckLevel: 3000, + TimerAckLevel: ts.Add(-1 * time.Hour), + ClusterTransferAckLevel: map[string]int64{"cluster1": 3000}, + ClusterTimerAckLevel: map[string]time.Time{"cluster1": ts.Add(-1 * time.Hour)}, + TransferProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("transferqueue")}, + CrossClusterProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("xclusterqueue")}, + TimerProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("timerqueue")}, + ClusterReplicationLevel: map[string]int64{"cluster2": 1500}, + DomainNotificationVersion: 3, + PendingFailoverMarkers: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("failovermarkers")}, + }, + wantQueries: []string{ + `SELECT shard, range_id FROM executions WHERE ` + + `shard_id = 15 and type = 0 and ` + + `domain_id = 10000000-1000-f000-f000-000000000000 and ` + + `workflow_id = 20000000-1000-f000-f000-000000000000 and ` + + `run_id = 30000000-1000-f000-f000-000000000000 and ` + + `visibility_ts = 946684800000 and ` + + `task_id = -11`, + }, + }, + { + name: "success with shard map missing some fields", + shardID: 15, + cluster: "cluster1", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScan(gomock.Any()).DoAndReturn(func(m map[string]interface{}) error { + m["range_id"] = int64(1000) + sm := testdata.NewShardMap(ts) + // delete some fields and validate they are initialized properly + delete(sm, "cluster_transfer_ack_level") + delete(sm, "cluster_timer_ack_level") + delete(sm, "cluster_replication_level") + delete(sm, "replication_dlq_ack_level") + m["shard"] = sm + return nil + }).Times(1) + }, + wantRangeID: int64(1000), + wantShardInfo: &persistence.InternalShardInfo{ + ShardID: 15, + Owner: "owner", + RangeID: 1000, + UpdatedAt: ts, + ReplicationAckLevel: 2000, + ReplicationDLQAckLevel: map[string]int64{}, // this was reset to empty map + TransferAckLevel: 3000, + TimerAckLevel: ts.Add(-1 * time.Hour), + ClusterTransferAckLevel: map[string]int64{"cluster1": 3000}, + ClusterTimerAckLevel: map[string]time.Time{"cluster1": ts.Add(-1 * time.Hour)}, + TransferProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("transferqueue")}, + CrossClusterProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("xclusterqueue")}, + TimerProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("timerqueue")}, + ClusterReplicationLevel: map[string]int64{}, // this was reset to empty map + DomainNotificationVersion: 3, + PendingFailoverMarkers: &persistence.DataBlob{Encoding: "thriftrw", Data: []uint8("failovermarkers")}, + }, + wantQueries: []string{ + `SELECT shard, range_id FROM executions WHERE ` + + `shard_id = 15 and type = 0 and ` + + `domain_id = 10000000-1000-f000-f000-000000000000 and ` + + `workflow_id = 20000000-1000-f000-f000-000000000000 and ` + + `run_id = 30000000-1000-f000-f000-000000000000 and ` + + `visibility_ts = 946684800000 and ` + + `task_id = -11`, + }, + }, + { + name: "mapscan failed", + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScan(gomock.Any()).DoAndReturn(func(m map[string]interface{}) error { + return errors.New("mapscan failed") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + timeSrc := clock.NewMockedTimeSourceAt(ts) + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client), dbWithTimeSource(timeSrc)) + + gotRangeID, gotShardInfo, err := db.SelectShard(context.Background(), tc.shardID, tc.cluster) + + if (err != nil) != tc.wantErr { + t.Errorf("SelectShard() error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if gotRangeID != tc.wantRangeID { + t.Errorf("Got RangeID = %v, want %v", gotRangeID, tc.wantRangeID) + } + + if diff := cmp.Diff(tc.wantShardInfo, gotShardInfo); diff != "" { + t.Fatalf("ShardInfo mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestUpdateRangeID(t *testing.T) { + ts, err := time.Parse(time.RFC3339, "2024-04-02T18:00:00Z") + if err != nil { + t.Fatalf("Failed to parse time: %v", err) + } + + tests := []struct { + name string + shardID int + rangeID int64 + prevRangeID int64 + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "successfully applied", + shardID: 15, + rangeID: 1000, + prevRangeID: 999, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + }, + wantQueries: []string{ + `UPDATE executions SET range_id = 1000 WHERE ` + + `shard_id = 15 and ` + + `type = 0 and ` + + `domain_id = 10000000-1000-f000-f000-000000000000 and ` + + `workflow_id = 20000000-1000-f000-f000-000000000000 and ` + + `run_id = 30000000-1000-f000-f000-000000000000 and ` + + `visibility_ts = 946684800000 and ` + + `task_id = -11 ` + + `IF range_id = 999`, + }, + }, + { + name: "not applied", + shardID: 15, + rangeID: 1000, + prevRangeID: 999, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + m["range_id"] = int64(1001) + return false, nil + }).Times(1) + }, + wantErr: true, + }, + { + name: "mapscancas failed", + shardID: 15, + rangeID: 1000, + prevRangeID: 999, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return false, errors.New("mapscancas failed") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + timeSrc := clock.NewMockedTimeSourceAt(ts) + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client), dbWithTimeSource(timeSrc)) + + err := db.UpdateRangeID(context.Background(), tc.shardID, tc.rangeID, tc.prevRangeID) + + if (err != nil) != tc.wantErr { + t.Errorf("UpdateRangeID() error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestUpdateShard(t *testing.T) { + ts, err := time.Parse(time.RFC3339, "2024-04-02T18:00:00Z") + if err != nil { + t.Fatalf("Failed to parse time: %v", err) + } + + tests := []struct { + name string + row *nosqlplugin.ShardRow + prevRangeID int64 + queryMockFn func(query *gocql.MockQuery) + wantQueries []string + wantErr bool + }{ + { + name: "successfully applied", + row: testdata.NewShardRow(ts), + prevRangeID: 988, + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return true, nil + }).Times(1) + }, + wantQueries: []string{ + `UPDATE executions SET shard = {` + + `shard_id: 15, ` + + `owner: owner, ` + + `range_id: 1000, ` + + `stolen_since_renew: 0, ` + + `updated_at: 1712080800000, ` + + `replication_ack_level: 2000, ` + + `transfer_ack_level: 3000, ` + + `timer_ack_level: 2024-04-02T17:00:00Z, ` + + `cluster_transfer_ack_level: map[cluster2:4000], ` + + `cluster_timer_ack_level: map[cluster2:2024-04-02 16:00:00 +0000 UTC], ` + + `transfer_processing_queue_states: [116 114 97 110 115 102 101 114 113 117 101 117 101], ` + + `transfer_processing_queue_states_encoding: thriftrw, ` + + `cross_cluster_processing_queue_states: [120 99 108 117 115 116 101 114 113 117 101 117 101], ` + + `cross_cluster_processing_queue_states_encoding: thriftrw, ` + + `timer_processing_queue_states: [116 105 109 101 114 113 117 101 117 101], ` + + `timer_processing_queue_states_encoding: thriftrw, ` + + `domain_notification_version: 3, ` + + `cluster_replication_level: map[cluster2:5000], ` + + `replication_dlq_ack_level: map[cluster2:10], ` + + `pending_failover_markers: [102 97 105 108 111 118 101 114 109 97 114 107 101 114 115], ` + + `pending_failover_markers_encoding: thriftrw ` + + `}, ` + + `range_id = 1000 ` + + `WHERE ` + + `shard_id = 15 and ` + + `type = 0 and ` + + `domain_id = 10000000-1000-f000-f000-000000000000 and ` + + `workflow_id = 20000000-1000-f000-f000-000000000000 and ` + + `run_id = 30000000-1000-f000-f000-000000000000 and ` + + `visibility_ts = 946684800000 and ` + + `task_id = -11 ` + + `IF range_id = 988`, + }, + }, + { + name: "not applied", + row: testdata.NewShardRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + m["range_id"] = int64(1001) + return false, nil + }).Times(1) + }, + wantErr: true, + }, + { + name: "mapscancas failed", + row: testdata.NewShardRow(ts), + queryMockFn: func(query *gocql.MockQuery) { + query.EXPECT().WithContext(gomock.Any()).Return(query).Times(1) + query.EXPECT().MapScanCAS(gomock.Any()).DoAndReturn(func(m map[string]interface{}) (bool, error) { + return false, errors.New("mapscancas failed") + }).Times(1) + }, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + query := gocql.NewMockQuery(ctrl) + tc.queryMockFn(query) + session := &fakeSession{ + query: query, + } + client := gocql.NewMockClient(ctrl) + cfg := &config.NoSQL{} + logger := testlogger.New(t) + dc := &persistence.DynamicConfiguration{} + timeSrc := clock.NewMockedTimeSourceAt(ts) + db := newCassandraDBFromSession(cfg, session, logger, dc, dbWithClient(client), dbWithTimeSource(timeSrc)) + + err := db.UpdateShard(context.Background(), tc.row, tc.prevRangeID) + + if (err != nil) != tc.wantErr { + t.Errorf("UpdateShard() error = %v, wantErr %v", err, tc.wantErr) + } + + if err != nil { + return + } + + t.Log(session.queries[0]) + if diff := cmp.Diff(tc.wantQueries, session.queries); diff != "" { + t.Fatalf("Query mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/tasks_test.go b/common/persistence/nosql/nosqlplugin/cassandra/tasks_test.go index e94360094d4..5a824afcf2f 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/tasks_test.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/tasks_test.go @@ -83,7 +83,7 @@ func TestSelectTaskList(t *testing.T) { }, }, { - name: "scal failure", + name: "scan failure", filter: &nosqlplugin.TaskListFilter{ DomainID: "domain1", TaskListName: "tasklist1", diff --git a/common/persistence/nosql/nosqlplugin/cassandra/testdata/domain.go b/common/persistence/nosql/nosqlplugin/cassandra/testdata/domain.go new file mode 100644 index 00000000000..e294673722d --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/testdata/domain.go @@ -0,0 +1,71 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package testdata + +import ( + "time" + + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" + "github.com/uber/cadence/common/types" +) + +func NewDomainRow(ts time.Time) *nosqlplugin.DomainRow { + return &nosqlplugin.DomainRow{ + Info: &persistence.DomainInfo{ + ID: "test-domain-id", + Name: "test-domain-name", + Status: persistence.DomainStatusRegistered, + Description: "test-domain-description", + OwnerEmail: "test-domain-owner-email", + Data: map[string]string{"k1": "v1"}, + }, + Config: &nosqlplugin.NoSQLInternalDomainConfig{ + Retention: 7 * 24 * time.Hour, + EmitMetric: true, + ArchivalBucket: "test-archival-bucket", + ArchivalStatus: types.ArchivalStatusEnabled, + HistoryArchivalStatus: types.ArchivalStatusEnabled, + HistoryArchivalURI: "test-history-archival-uri", + VisibilityArchivalStatus: types.ArchivalStatusEnabled, + VisibilityArchivalURI: "test-visibility-archival-uri", + BadBinaries: &persistence.DataBlob{Encoding: "thriftrw", Data: []byte("bad-binaries")}, + IsolationGroups: &persistence.DataBlob{Encoding: "thriftrw", Data: []byte("isolation-group")}, + AsyncWorkflowsConfig: &persistence.DataBlob{Encoding: "thriftrw", Data: []byte("async-workflows-config")}, + }, + ReplicationConfig: &persistence.DomainReplicationConfig{ + ActiveClusterName: "test-active-cluster-name", + Clusters: []*persistence.ClusterReplicationConfig{ + { + ClusterName: "test-cluster-name", + }, + }, + }, + IsGlobalDomain: true, + ConfigVersion: 3, + FailoverVersion: 4, + FailoverEndTime: &ts, + LastUpdatedTime: ts, + NotificationVersion: 5, + } +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/testdata/shard.go b/common/persistence/nosql/nosqlplugin/cassandra/testdata/shard.go new file mode 100644 index 00000000000..fdac47c772b --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/testdata/shard.go @@ -0,0 +1,80 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package testdata + +import ( + "time" + + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" +) + +func NewShardRow(ts time.Time) *nosqlplugin.ShardRow { + return &nosqlplugin.ShardRow{ + ShardID: 15, + Owner: "owner", + RangeID: 1000, + ReplicationAckLevel: 2000, + TransferAckLevel: 3000, + TimerAckLevel: ts.Add(-time.Hour), + ClusterTransferAckLevel: map[string]int64{"cluster2": 4000}, + ClusterTimerAckLevel: map[string]time.Time{"cluster2": ts.Add(-2 * time.Hour)}, + DomainNotificationVersion: 3, + ClusterReplicationLevel: map[string]int64{"cluster2": 5000}, + ReplicationDLQAckLevel: map[string]int64{"cluster2": 10}, + PendingFailoverMarkers: &persistence.DataBlob{Encoding: "thriftrw", Data: []byte("failovermarkers")}, + TransferProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []byte("transferqueue")}, + CrossClusterProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []byte("xclusterqueue")}, + TimerProcessingQueueStates: &persistence.DataBlob{Encoding: "thriftrw", Data: []byte("timerqueue")}, + } +} + +func NewShardMap(ts time.Time) map[string]interface{} { + return map[string]interface{}{ + "shard_id": int(15), + "range_id": int64(1000), + "owner": "owner", + "stolen_since_renew": 0, + "updated_at": ts, + "replication_ack_level": int64(2000), + "transfer_ack_level": int64(3000), + "timer_ack_level": ts.Add(-1 * time.Hour), + "cluster_transfer_ack_level": map[string]int64{ + "cluster1": int64(3000), + }, + "cluster_timer_ack_level": map[string]time.Time{ + "cluster1": ts.Add(-1 * time.Hour), + }, + "transfer_processing_queue_states": []byte("transferqueue"), + "transfer_processing_queue_states_encoding": "thriftrw", + "cross_cluster_processing_queue_states": []byte("xclusterqueue"), + "cross_cluster_processing_queue_states_encoding": "thriftrw", + "timer_processing_queue_states": []byte("timerqueue"), + "timer_processing_queue_states_encoding": "thriftrw", + "domain_notification_version": int64(3), + "cluster_replication_level": map[string]int64{"cluster2": 1500}, + "replication_dlq_ack_level": map[string]int64{"cluster2": 5}, + "pending_failover_markers": []byte("failovermarkers"), + "pending_failover_markers_encoding": "thriftrw", + } +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/workflow_parsing_utils.go b/common/persistence/nosql/nosqlplugin/cassandra/workflow_parsing_utils.go index 551e1ed5f9b..ceee5f4a872 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/workflow_parsing_utils.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/workflow_parsing_utils.go @@ -542,19 +542,7 @@ func parseTransferTaskInfo( func parseCrossClusterTaskInfo( result map[string]interface{}, ) *persistence.CrossClusterTaskInfo { - info := (*persistence.CrossClusterTaskInfo)(parseTransferTaskInfo(result)) - if persistence.CrossClusterTaskDefaultTargetRunID == persistence.TransferTaskTransferTargetRunID { - return info - } - - // incase CrossClusterTaskDefaultTargetRunID is updated and not equal to TransferTaskTransferTargetRunID - if v, ok := result["target_run_id"]; ok { - info.TargetRunID = v.(gocql.UUID).String() - if info.TargetRunID == persistence.CrossClusterTaskDefaultTargetRunID { - info.TargetRunID = "" - } - } - return info + return parseTransferTaskInfo(result) } func parseReplicationTaskInfo( diff --git a/common/persistence/nosql/nosqlplugin/cassandra/workflow_parsing_utils_test.go b/common/persistence/nosql/nosqlplugin/cassandra/workflow_parsing_utils_test.go new file mode 100644 index 00000000000..dadfd6e96a6 --- /dev/null +++ b/common/persistence/nosql/nosqlplugin/cassandra/workflow_parsing_utils_test.go @@ -0,0 +1,523 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package cassandra + +import ( + "testing" + "time" + + cql "github.com/gocql/gocql" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/persistence/nosql/nosqlplugin" +) + +type mockUUID struct { + uuid string +} + +func (m mockUUID) String() string { + return m.uuid +} + +func newMockUUID(s string) mockUUID { + return mockUUID{s} +} + +func Test_parseWorkflowExecutionInfo(t *testing.T) { + + completionEventData := []byte("completion event data") + autoResetPointsData := []byte("auto reset points data") + searchAttributes := map[string][]byte{"AttributeKey": []byte("AttributeValue")} + memo := map[string][]byte{"MemoKey": []byte("MemoValue")} + partitionConfig := map[string]string{"PartitionKey": "PartitionValue"} + timeNow := time.Now() + + tests := []struct { + args map[string]interface{} + want *persistence.InternalWorkflowExecutionInfo + }{ + { + args: map[string]interface{}{ + "domain_id": newMockUUID("domain_id"), + "workflow_id": "workflow_id", + "run_id": newMockUUID("run_id"), + "parent_workflow_id": "parent_workflow_id", + "initiated_id": int64(1), + "completion_event_batch_id": int64(2), + "task_list": "task_list", + "workflow_type_name": "workflow_type_name", + "workflow_timeout": 10, + "decision_task_timeout": 5, + "execution_context": []byte("execution context"), + "state": 1, + "close_status": 2, + "last_first_event_id": int64(3), + "last_event_task_id": int64(4), + "next_event_id": int64(5), + "last_processed_event": int64(6), + "start_time": timeNow, + "last_updated_time": timeNow, + "create_request_id": newMockUUID("create_request_id"), + "signal_count": 7, + "history_size": int64(8), + "decision_version": int64(9), + "decision_schedule_id": int64(10), + "decision_started_id": int64(11), + "decision_request_id": "decision_request_id", + "decision_timeout": 8, + "decision_timestamp": int64(200), + "decision_scheduled_timestamp": int64(201), + "decision_original_scheduled_timestamp": int64(202), + "decision_attempt": int64(203), + "cancel_requested": true, + "cancel_request_id": "cancel_request_id", + "sticky_task_list": "sticky_task_list", + "sticky_schedule_to_start_timeout": 9, + "client_library_version": "client_lib_version", + "client_feature_version": "client_feature_version", + "client_impl": "client_impl", + "attempt": 12, + "has_retry_policy": true, + "init_interval": 10, + "backoff_coefficient": 1.5, + "max_interval": 20, + "max_attempts": 13, + "expiration_time": timeNow, + "non_retriable_errors": []string{"error1", "error2"}, + "branch_token": []byte("branch token"), + "cron_schedule": "cron_schedule", + "expiration_seconds": 14, + "search_attributes": searchAttributes, + "memo": memo, + "partition_config": partitionConfig, + "completion_event": completionEventData, + "completion_event_data_encoding": "Proto3", + "auto_reset_points": autoResetPointsData, + "auto_reset_points_encoding": "Proto3", + }, + want: &persistence.InternalWorkflowExecutionInfo{ + DomainID: "domain_id", + WorkflowID: "workflow_id", + RunID: "run_id", + ParentWorkflowID: "parent_workflow_id", + InitiatedID: int64(1), + CompletionEventBatchID: int64(2), + TaskList: "task_list", + WorkflowTypeName: "workflow_type_name", + WorkflowTimeout: common.SecondsToDuration(int64(10)), + DecisionStartToCloseTimeout: common.SecondsToDuration(int64(5)), + ExecutionContext: []byte("execution context"), + State: 1, + CloseStatus: 2, + LastFirstEventID: int64(3), + LastEventTaskID: int64(4), + NextEventID: int64(5), + LastProcessedEvent: int64(6), + StartTimestamp: timeNow, + LastUpdatedTimestamp: timeNow, + CreateRequestID: "create_request_id", + SignalCount: int32(7), + HistorySize: int64(8), + DecisionVersion: int64(9), + DecisionScheduleID: int64(10), + DecisionStartedID: int64(11), + DecisionRequestID: "decision_request_id", + DecisionTimeout: common.SecondsToDuration(int64(8)), + DecisionStartedTimestamp: time.Unix(0, int64(200)), + DecisionScheduledTimestamp: time.Unix(0, int64(201)), + DecisionOriginalScheduledTimestamp: time.Unix(0, int64(202)), + DecisionAttempt: int64(203), + CancelRequested: true, + CancelRequestID: "cancel_request_id", + StickyTaskList: "sticky_task_list", + StickyScheduleToStartTimeout: common.SecondsToDuration(int64(9)), + ClientLibraryVersion: "client_lib_version", + ClientFeatureVersion: "client_feature_version", + ClientImpl: "client_impl", + Attempt: int32(12), + HasRetryPolicy: true, + InitialInterval: common.SecondsToDuration(int64(10)), + BackoffCoefficient: 1.5, + MaximumInterval: common.SecondsToDuration(int64(20)), + MaximumAttempts: int32(13), + ExpirationTime: timeNow, + NonRetriableErrors: []string{"error1", "error2"}, + Memo: memo, + PartitionConfig: partitionConfig, + }, + }, + { + args: map[string]interface{}{ + "first_run_id": newMockUUID("first_run_id"), + "parent_domain_id": newMockUUID("parent_domain_id"), + "parent_run_id": newMockUUID("parent_run_id"), + }, + want: &persistence.InternalWorkflowExecutionInfo{ + FirstExecutionRunID: "first_run_id", + ParentDomainID: "parent_domain_id", + ParentRunID: "parent_run_id", + }, + }, + { + args: map[string]interface{}{ + "first_run_id": newMockUUID(emptyRunID), + "parent_domain_id": newMockUUID(emptyDomainID), + "parent_run_id": newMockUUID(emptyRunID), + }, + want: &persistence.InternalWorkflowExecutionInfo{}, + }, + { + args: map[string]interface{}{ + "first_run_id": newMockUUID(cql.UUID{}.String()), + }, + want: &persistence.InternalWorkflowExecutionInfo{ + FirstExecutionRunID: "", + }, + }, + } + for _, tt := range tests { + result := parseWorkflowExecutionInfo(tt.args) + assert.Equal(t, result.FirstExecutionRunID, tt.want.FirstExecutionRunID) + assert.Equal(t, result.DomainID, tt.want.DomainID) + assert.Equal(t, result.WorkflowID, tt.want.WorkflowID) + assert.Equal(t, result.RunID, tt.want.RunID) + assert.Equal(t, result.ParentWorkflowID, tt.want.ParentWorkflowID) + assert.Equal(t, result.InitiatedID, tt.want.InitiatedID) + assert.Equal(t, result.CompletionEventBatchID, tt.want.CompletionEventBatchID) + assert.Equal(t, result.TaskList, tt.want.TaskList) + assert.Equal(t, result.WorkflowTypeName, tt.want.WorkflowTypeName) + assert.Equal(t, result.WorkflowTimeout, tt.want.WorkflowTimeout) + assert.Equal(t, result.DecisionStartToCloseTimeout, tt.want.DecisionStartToCloseTimeout) + assert.Equal(t, result.ExecutionContext, tt.want.ExecutionContext) + assert.Equal(t, result.State, tt.want.State) + assert.Equal(t, result.CloseStatus, tt.want.CloseStatus) + assert.Equal(t, result.LastFirstEventID, tt.want.LastFirstEventID) + assert.Equal(t, result.LastEventTaskID, tt.want.LastEventTaskID) + assert.Equal(t, result.NextEventID, tt.want.NextEventID) + assert.Equal(t, result.LastProcessedEvent, tt.want.LastProcessedEvent) + assert.Equal(t, result.StartTimestamp, tt.want.StartTimestamp) + assert.Equal(t, result.LastUpdatedTimestamp, tt.want.LastUpdatedTimestamp) + assert.Equal(t, result.CreateRequestID, tt.want.CreateRequestID) + assert.Equal(t, result.SignalCount, tt.want.SignalCount) + assert.Equal(t, result.HistorySize, tt.want.HistorySize) + assert.Equal(t, result.DecisionVersion, tt.want.DecisionVersion) + assert.Equal(t, result.DecisionScheduleID, tt.want.DecisionScheduleID) + assert.Equal(t, result.DecisionStartedID, tt.want.DecisionStartedID) + assert.Equal(t, result.DecisionRequestID, tt.want.DecisionRequestID) + assert.Equal(t, result.DecisionTimeout, tt.want.DecisionTimeout) + assert.Equal(t, result.CancelRequested, tt.want.CancelRequested) + assert.Equal(t, result.DecisionStartedTimestamp, tt.want.DecisionStartedTimestamp) + assert.Equal(t, result.DecisionScheduledTimestamp, tt.want.DecisionScheduledTimestamp) + assert.Equal(t, result.DecisionOriginalScheduledTimestamp, tt.want.DecisionOriginalScheduledTimestamp) + assert.Equal(t, result.DecisionAttempt, tt.want.DecisionAttempt) + assert.Equal(t, result.ParentDomainID, tt.want.ParentDomainID) + } +} + +func Test_parseReplicationState(t *testing.T) { + tests := []struct { + args map[string]interface{} + want *persistence.ReplicationState + }{ + { + args: map[string]interface{}{ + "current_version": int64(1), + "start_version": int64(2), + "last_write_version": int64(3), + "last_write_event_id": int64(4), + "last_replication_info": map[string]map[string]interface{}{ + "map1": { + "version": int64(5), + "last_event_id": int64(6), + }, + "map2": { + "version": int64(7), + "last_event_id": int64(8), + }, + }, + }, + want: &persistence.ReplicationState{ + CurrentVersion: int64(1), + StartVersion: int64(2), + LastWriteVersion: int64(3), + LastWriteEventID: int64(4), + LastReplicationInfo: map[string]*persistence.ReplicationInfo{ + "map1": { + Version: int64(5), + LastEventID: int64(6), + }, + "map2": { + Version: int64(7), + LastEventID: int64(8), + }, + }, + }, + }, + } + for _, tt := range tests { + result := parseReplicationState(tt.args) + assert.Equal(t, result.CurrentVersion, tt.want.CurrentVersion) + assert.Equal(t, result.StartVersion, tt.want.StartVersion) + assert.Equal(t, result.LastWriteVersion, tt.want.LastWriteVersion) + assert.Equal(t, result.LastWriteEventID, tt.want.LastWriteEventID) + assert.Equal(t, result.LastReplicationInfo, tt.want.LastReplicationInfo) + } +} + +func Test_parseActivityInfo(t *testing.T) { + timeNow := time.Now() + testInput := map[string]interface{}{ + "version": int64(1), + "schedule_id": int64(2), + "scheduled_event_batch_id": int64(3), + "scheduled_event": []byte("scheduled_event"), + "scheduled_time": timeNow, + "started_id": int64(4), + "started_event": []byte("started_event"), + "started_time": timeNow, + "activity_id": "activity_id", + "request_id": "request_id", + "details": []byte("details"), + "schedule_to_start_timeout": 5, + "schedule_to_close_timeout": 6, + "start_to_close_timeout": 7, + "heart_beat_timeout": 8, + "cancel_requested": true, + "cancel_request_id": int64(9), + "last_hb_updated_time": timeNow, + "timer_task_status": 9, + "attempt": 10, + "task_list": "task_list", + "started_identity": "started_identity", + "has_retry_policy": true, + "init_interval": 11, + "backoff_coefficient": 1.5, + "max_interval": 12, + "max_attempts": 13, + "expiration_time": timeNow, + "non_retriable_errors": []string{"error1", "error2"}, + "last_failure_reason": "last_failure_reason", + "last_worker_identity": "last_worker_identity", + "last_failure_details": []byte("last_failure_details"), + "event_data_encoding": "Proto3", + } + + expected := &persistence.InternalActivityInfo{ + Version: int64(1), + ScheduleID: int64(2), + ScheduledEventBatchID: int64(3), + ScheduledEvent: persistence.NewDataBlob([]byte("scheduled_event"), "Proto3"), + ScheduledTime: timeNow, + StartedID: int64(4), + StartedEvent: persistence.NewDataBlob([]byte("started_event"), "Proto3"), + StartedTime: timeNow, + ActivityID: "activity_id", + RequestID: "request_id", + Details: []byte("details"), + ScheduleToStartTimeout: common.SecondsToDuration(int64(5)), + ScheduleToCloseTimeout: common.SecondsToDuration(int64(6)), + StartToCloseTimeout: common.SecondsToDuration(int64(7)), + HeartbeatTimeout: common.SecondsToDuration(int64(8)), + CancelRequested: true, + CancelRequestID: int64(9), + LastHeartBeatUpdatedTime: timeNow, + TimerTaskStatus: int32(9), + Attempt: int32(10), + TaskList: "task_list", + StartedIdentity: "started_identity", + HasRetryPolicy: true, + InitialInterval: common.SecondsToDuration(int64(11)), + BackoffCoefficient: 1.5, + MaximumInterval: common.SecondsToDuration(int64(12)), + MaximumAttempts: int32(13), + ExpirationTime: timeNow, + NonRetriableErrors: []string{"error1", "error2"}, + LastFailureReason: "last_failure_reason", + LastWorkerIdentity: "last_worker_identity", + LastFailureDetails: []byte("last_failure_details"), + DomainID: "domain_id", + } + + assert.Equal(t, expected, parseActivityInfo("domain_id", testInput)) +} + +func Test_parseTimerInfo(t *testing.T) { + timeNow := time.Now() + testInput := map[string]interface{}{ + "version": int64(1), + "timer_id": "timer_id", + "started_id": int64(2), + "expiry_time": timeNow, + "task_id": int64(3), + } + expected := &persistence.TimerInfo{ + Version: int64(1), + TimerID: "timer_id", + StartedID: int64(2), + ExpiryTime: timeNow, + TaskStatus: int64(3), + } + assert.Equal(t, expected, parseTimerInfo(testInput)) +} + +func Test_parseChildExecutionInfo(t *testing.T) { + startedRunID := newMockUUID("started_run_id") + createRequestID := newMockUUID("create_request_id") + domainID := newMockUUID("domain_id") + testInput := map[string]interface{}{ + "version": int64(1), + "initiated_id": int64(2), + "initiated_event_batch_id": int64(3), + "initiated_event": []byte("initiated_event"), + "started_id": int64(4), + "started_workflow_id": "started_workflow_id", + "started_run_id": startedRunID, + "started_event": []byte("started_event"), + "create_request_id": createRequestID, + "event_data_encoding": "Proto3", + "domain_id": domainID, + "workflow_type_name": "workflow_type_name", + "parent_close_policy": 1, + } + expected := &persistence.InternalChildExecutionInfo{ + Version: int64(1), + InitiatedID: int64(2), + InitiatedEventBatchID: int64(3), + InitiatedEvent: persistence.NewDataBlob([]byte("initiated_event"), "Proto3"), + StartedID: int64(4), + StartedWorkflowID: "started_workflow_id", + StartedRunID: startedRunID.String(), + StartedEvent: persistence.NewDataBlob([]byte("started_event"), "Proto3"), + CreateRequestID: createRequestID.String(), + DomainID: domainID.String(), + WorkflowTypeName: "workflow_type_name", + ParentClosePolicy: 1, + } + assert.Equal(t, expected, parseChildExecutionInfo(testInput)) + + // edge case + testInput = map[string]interface{}{ + "domain_id": newMockUUID(_emptyUUID.String()), + "domain_name": "domain_name", + } + assert.Equal(t, "domain_name", parseChildExecutionInfo(testInput).DomainNameDEPRECATED) + assert.Equal(t, "", parseChildExecutionInfo(testInput).DomainID) +} + +func Test_parseRequestCancelInfo(t *testing.T) { + testInput := map[string]interface{}{ + "version": int64(1), + "initiated_id": int64(2), + "initiated_event_batch_id": int64(3), + "cancel_request_id": "cancel_request_id", + } + expected := &persistence.RequestCancelInfo{ + Version: int64(1), + InitiatedID: int64(2), + InitiatedEventBatchID: int64(3), + CancelRequestID: "cancel_request_id", + } + assert.Equal(t, expected, parseRequestCancelInfo(testInput)) +} + +func Test_parseSignalInfo(t *testing.T) { + testInput := map[string]interface{}{ + "version": int64(1), + "initiated_id": int64(2), + "initiated_event_batch_id": int64(3), + "signal_request_id": newMockUUID("signal_request_id"), + "signal_name": "signal_name", + "input": []byte("input"), + "control": []byte("control"), + } + expected := &persistence.SignalInfo{ + Version: int64(1), + InitiatedID: int64(2), + InitiatedEventBatchID: int64(3), + SignalName: "signal_name", + SignalRequestID: "signal_request_id", + Input: []byte("input"), + Control: []byte("control"), + } + assert.Equal(t, expected, parseSignalInfo(testInput)) +} + +func Test_parseTimerTaskInfo(t *testing.T) { + timeNow := time.Now() + testInput := map[string]interface{}{ + "version": int64(1), + "visibility_ts": timeNow, + "task_id": int64(2), + "run_id": newMockUUID("run_id"), + "type": 3, + "timeout_type": 3, + "event_id": int64(4), + "schedule_attempt": int64(5), + } + expected := &persistence.TimerTaskInfo{ + Version: int64(1), + VisibilityTimestamp: timeNow, + TaskID: int64(2), + RunID: "run_id", + TaskType: 3, + TimeoutType: 3, + EventID: int64(4), + ScheduleAttempt: int64(5), + } + assert.Equal(t, expected, parseTimerTaskInfo(testInput)) +} + +func Test_parseReplicationTaskInfo(t *testing.T) { + testInput := map[string]interface{}{ + "domain_id": newMockUUID("domain_id"), + "workflow_id": "workflow_id", + "run_id": newMockUUID("run_id"), + "task_id": int64(1), + "type": 2, + "first_event_id": int64(3), + "next_event_id": int64(4), + "version": int64(5), + "scheduled_id": int64(6), + "branch_token": []byte("branch_token"), + "new_run_branch_token": []byte("new_run_branch_token"), + "created_time": int64(7), + } + expected := &nosqlplugin.ReplicationTask{ + DomainID: "domain_id", + WorkflowID: "workflow_id", + RunID: "run_id", + TaskID: int64(1), + TaskType: 2, + FirstEventID: int64(3), + NextEventID: int64(4), + Version: int64(5), + ScheduledID: int64(6), + BranchToken: []byte("branch_token"), + NewRunBranchToken: []byte("new_run_branch_token"), + CreationTime: time.Unix(0, 7), + } + assert.Equal(t, expected, parseReplicationTaskInfo(testInput)) +} diff --git a/common/persistence/nosql/nosqlplugin/cassandra/workflow_utils_test.go b/common/persistence/nosql/nosqlplugin/cassandra/workflow_utils_test.go index 67c01bc2c85..40f8096931b 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/workflow_utils_test.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/workflow_utils_test.go @@ -131,17 +131,47 @@ func sanitizedRenderedQuery(queryTmpl string, args ...interface{}) string { type fakeIter struct { // input parametrs mapScanInputs []map[string]interface{} + scanInputs [][]interface{} pageState []byte closeErr error // output parameters mapScanCalls int + scanCalls int closed bool } // Scan is fake implementation of gocql.Iter.Scan -func (i *fakeIter) Scan(...interface{}) bool { - return false +func (i *fakeIter) Scan(outArgs ...interface{}) bool { + if i.scanCalls >= len(i.scanInputs) { + return false + } + + for j, v := range i.scanInputs[i.scanCalls] { + if len(outArgs) <= j { + panic(fmt.Sprintf("outArgs length: %d is less than expected: %d", len(outArgs), len(i.scanInputs[i.scanCalls]))) + } + + if v == nil { + continue + } + + dst := outArgs[j] + dstPtrValue := reflect.ValueOf(dst) + dstValue := reflect.Indirect(dstPtrValue) + + func() { + defer func() { + if r := recover(); r != nil { + panic(fmt.Sprintf("failed to set %dth value: %v to %v, inner panic: %s", j, v, dst, r)) + } + }() + dstValue.Set(reflect.ValueOf(v)) + }() + } + + i.scanCalls++ + return true } // MapScan is fake implementation of gocql.Iter.MapScan diff --git a/common/persistence/pinotiVsibilityDualManager.go b/common/persistence/pinot_visibility_dual_manager.go similarity index 100% rename from common/persistence/pinotiVsibilityDualManager.go rename to common/persistence/pinot_visibility_dual_manager.go diff --git a/common/persistence/pinot_visibility_dual_manager_test.go b/common/persistence/pinot_visibility_dual_manager_test.go new file mode 100644 index 00000000000..385696da8b4 --- /dev/null +++ b/common/persistence/pinot_visibility_dual_manager_test.go @@ -0,0 +1,1247 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package persistence + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log" +) + +func TestNewPinotVisibilityDualManager(t *testing.T) { + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + }{ + "Case1: nil case": { + mockDBVisibilityManager: nil, + mockPinotVisibilityManager: nil, + }, + "Case2: success case": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert.NotPanics(t, func() { + NewPinotVisibilityDualManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, nil, nil, log.NewNoop()) + }) + }) + } +} + +func TestPinotDualManagerClose(t *testing.T) { + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + }{ + "Case1-1: success case with DB visibility is not nil": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().Close().Return().Times(1) + }, + }, + "Case1-2: success case with Pinot visibility is not nil": { + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().Close().Return().Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, nil, nil, log.NewNoop()) + assert.NotPanics(t, func() { + visibilityManager.Close() + }) + }) + } +} + +func TestPinotDualManagerGetName(t *testing.T) { + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + }{ + "Case1-1: success case with DB visibility is not nil": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().GetName().Return(testTableName).Times(1) + + }, + }, + "Case1-2: success case with Pinot visibility is not nil": { + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().GetName().Return(testTableName).Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, nil, nil, log.NewNoop()) + + assert.NotPanics(t, func() { + visibilityManager.GetName() + }) + }) + } +} + +func TestPinotDualRecordWorkflowExecutionStarted(t *testing.T) { + request := &RecordWorkflowExecutionStartedRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *RecordWorkflowExecutionStartedRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionStarted(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualRecordWorkflowExecutionClosed(t *testing.T) { + request := &RecordWorkflowExecutionClosedRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + context context.Context + request *RecordWorkflowExecutionClosedRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case0-1: error case with advancedVisibilityWritingMode is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + expectedError: fmt.Errorf("error"), + }, + "Case0-2: error case with Pinot has errors in dual mode": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + expectedError: fmt.Errorf("error"), + }, + "Case1-1: success case with DB visibility is not nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + context: context.Background(), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with dual manager": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + expectedError: nil, + }, + "Case2-1: choose DB visibility manager when it is nil": { + context: context.Background(), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + }, + "Case2-2: choose Pinot visibility manager when it is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + }, + "Case2-3: choose both when Pinot is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + }, + "Case2-4: choose both when DB is nil": { + context: context.Background(), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + }, + "Case3-1: chooseVisibilityModeForAdmin when Pinot is nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "Case3-2: chooseVisibilityModeForAdmin when DB is nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "Case3-3: chooseVisibilityModeForAdmin when both are not nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionClosed(test.context, test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// test an edge case +func TestPinotChooseVisibilityModeForAdmin(t *testing.T) { + ctrl := gomock.NewController(t) + dbManager := NewMockVisibilityManager(ctrl) + esManager := NewMockVisibilityManager(ctrl) + mgr := NewPinotVisibilityDualManager(dbManager, esManager, nil, nil, log.NewNoop()) + dualManager := mgr.(*pinotVisibilityDualManager) + dualManager.dbVisibilityManager = nil + dualManager.pinotVisibilityManager = nil + assert.Equal(t, "INVALID_ADMIN_MODE", dualManager.chooseVisibilityModeForAdmin()) +} + +func TestPinotDualRecordWorkflowExecutionUninitialized(t *testing.T) { + request := &RecordWorkflowExecutionUninitializedRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *RecordWorkflowExecutionUninitializedRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionUninitialized(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionUninitialized(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionUninitialized(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualUpsertWorkflowExecution(t *testing.T) { + request := &UpsertWorkflowExecutionRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *UpsertWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.UpsertWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualDeleteWorkflowExecution(t *testing.T) { + request := &VisibilityDeleteWorkflowExecutionRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *VisibilityDeleteWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.DeleteWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualDeleteUninitializedWorkflowExecution(t *testing.T) { + request := &VisibilityDeleteWorkflowExecutionRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *VisibilityDeleteWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().DeleteUninitializedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().DeleteUninitializedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.DeleteUninitializedWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListOpenWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsRequest{ + Domain: "test-domain", + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListOpenWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListOpenWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListOpenWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListClosedWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsRequest{ + Domain: "test-domain", + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case2-1: success case with DB visibility is not nil and read mod is false": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(false), + expectedError: nil, + }, + "Case2-2: success case with Pinot visibility is not nil and read mod is false": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(false), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListOpenWorkflowExecutionsByType(t *testing.T) { + request := &ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByTypeRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListOpenWorkflowExecutionsByType(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListClosedWorkflowExecutionsByType(t *testing.T) { + request := &ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByTypeRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutionsByType(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListOpenWorkflowExecutionsByWorkflowID(t *testing.T) { + request := &ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByWorkflowIDRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListOpenWorkflowExecutionsByWorkflowID(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListClosedWorkflowExecutionsByWorkflowID(t *testing.T) { + request := &ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByWorkflowIDRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutionsByWorkflowID(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListClosedWorkflowExecutionsByStatus(t *testing.T) { + request := &ListClosedWorkflowExecutionsByStatusRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListClosedWorkflowExecutionsByStatusRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByStatus(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByStatus(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutionsByStatus(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualGetClosedWorkflowExecution(t *testing.T) { + request := &GetClosedWorkflowExecutionRequest{ + Domain: "test-domain", + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *GetClosedWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().GetClosedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().GetClosedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.GetClosedWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualListWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsByQueryRequest{ + Domain: "test-domain", + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByQueryRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ListWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualScanWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsByQueryRequest{ + Domain: "test-domain", + } + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByQueryRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ScanWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().ScanWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ScanWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotDualCountWorkflowExecutions(t *testing.T) { + request := &CountWorkflowExecutionsRequest{ + Domain: "test-domain", + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *CountWorkflowExecutionsRequest + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityDualManager(test.mockDBVisibilityManager, + test.mockPinotVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.CountWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/common/persistence/pinotVisibilityTripleManager.go b/common/persistence/pinot_visibility_triple_manager.go similarity index 100% rename from common/persistence/pinotVisibilityTripleManager.go rename to common/persistence/pinot_visibility_triple_manager.go diff --git a/common/persistence/pinot_visibility_triple_manager_test.go b/common/persistence/pinot_visibility_triple_manager_test.go new file mode 100644 index 00000000000..b7b6b2e3c28 --- /dev/null +++ b/common/persistence/pinot_visibility_triple_manager_test.go @@ -0,0 +1,891 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package persistence + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log" +) + +func TestNewPinotVisibilityTripleManager(t *testing.T) { + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + }{ + "Case1: nil case": { + mockDBVisibilityManager: nil, + mockESVisibilityManager: nil, + mockPinotVisibilityManager: nil, + }, + "Case2: success case": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert.NotPanics(t, func() { + NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + nil, nil, log.NewNoop()) + }) + }) + } +} + +func TestPinotTripleManagerClose(t *testing.T) { + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + }{ + "Case1-1: success case with DB visibility is not nil": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().Close().Return().Times(1) + }, + }, + "Case1-2: success case with ES visibility is not nil": { + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().Close().Return().Times(1) + }, + }, + "Case1-3: success case with pinot visibility is not nil": { + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().Close().Return().Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + nil, nil, log.NewNoop()) + assert.NotPanics(t, func() { + visibilityManager.Close() + }) + }) + } +} + +func TestPinotTripleManagerGetName(t *testing.T) { + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + }{ + "Case1-1: success case with DB visibility is not nil": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().GetName().Return(testTableName).Times(1) + + }, + }, + "Case1-2: success case with ES visibility is not nil": { + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().GetName().Return(testTableName).Times(1) + }, + }, + "Case1-3: success case with pinot visibility is not nil": { + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().GetName().Return(testTableName).Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + nil, nil, log.NewNoop()) + + assert.NotPanics(t, func() { + visibilityManager.GetName() + }) + }) + } +} + +func TestPinotTripleRecordWorkflowExecutionStarted(t *testing.T) { + request := &RecordWorkflowExecutionStartedRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *RecordWorkflowExecutionStartedRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + test.advancedVisibilityWritingMode, nil, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionStarted(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotTripleRecordWorkflowExecutionClosed(t *testing.T) { + request := &RecordWorkflowExecutionClosedRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + context context.Context + request *RecordWorkflowExecutionClosedRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case0-1: error case with advancedVisibilityWritingMode is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + expectedError: fmt.Errorf("error"), + }, + "Case0-2: error case with ES has errors in dual mode": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + expectedError: fmt.Errorf("error"), + }, + "Case0-3: error case with ES has errors in On mode with Pinot is not nil": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: fmt.Errorf("error"), + }, + "Case0-4: error case with Pinot has errors in On mode": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: fmt.Errorf("error"), + }, + "Case0-5: error case with Pinot has errors in Triple mode": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeTriple), + expectedError: fmt.Errorf("error"), + }, + "Case0-6: error case with ES has errors in Triple mode": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeTriple), + expectedError: fmt.Errorf("error"), + }, + "Case1-1: success case with DB visibility is not nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with pinot visibility is not nil": { + context: context.Background(), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + }, + "Case1-4: success case with dual manager": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + expectedError: nil, + }, + "Case1-5: success case with dual manager when ES and Pinot are not nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + expectedError: nil, + }, + "Case1-6: success case with triple write when ES and Pinot are not nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeTriple), + expectedError: nil, + }, + "Case1-7: success case with triple write when ES and Pinot are nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeTriple), + expectedError: nil, + }, + "Case1-8: success case with triple write when db is nil": { + context: context.Background(), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeTriple), + expectedError: nil, + }, + "Case2-1: choose DB visibility manager when it is nil": { + context: context.Background(), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + }, + "Case2-2: choose ES visibility manager when it is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + }, + "Case2-3: choose both when ES is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + }, + "Case2-4: choose both when DB is nil": { + context: context.Background(), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + }, + "Case3-1: chooseVisibilityModeForAdmin when ES is nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "Case3-2: chooseVisibilityModeForAdmin when DB is nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "Case3-3: chooseVisibilityModeForAdmin when both are not nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "Case3-3: chooseVisibilityModeForAdmin when triple are not nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + test.advancedVisibilityWritingMode, nil, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionClosed(test.context, test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// test an edge case +func TestPinotTripleChooseVisibilityModeForAdmin(t *testing.T) { + ctrl := gomock.NewController(t) + dbManager := NewMockVisibilityManager(ctrl) + esManager := NewMockVisibilityManager(ctrl) + pntManager := NewMockVisibilityManager(ctrl) + mgr := NewPinotVisibilityTripleManager(dbManager, pntManager, esManager, nil, nil, + nil, nil, log.NewNoop()) + tripleManager := mgr.(*pinotVisibilityTripleManager) + tripleManager.dbVisibilityManager = nil + tripleManager.pinotVisibilityManager = nil + tripleManager.esVisibilityManager = nil + assert.Equal(t, "INVALID_ADMIN_MODE", tripleManager.chooseVisibilityModeForAdmin()) +} + +func TestPinotTripleRecordWorkflowExecutionUninitialized(t *testing.T) { + request := &RecordWorkflowExecutionUninitializedRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *RecordWorkflowExecutionUninitializedRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionUninitialized(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().RecordWorkflowExecutionUninitialized(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionUninitialized(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + test.advancedVisibilityWritingMode, nil, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionUninitialized(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotTripleUpsertWorkflowExecution(t *testing.T) { + request := &UpsertWorkflowExecutionRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *UpsertWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + test.advancedVisibilityWritingMode, nil, log.NewNoop()) + + err := visibilityManager.UpsertWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotTripleDeleteWorkflowExecution(t *testing.T) { + request := &VisibilityDeleteWorkflowExecutionRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *VisibilityDeleteWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + test.advancedVisibilityWritingMode, nil, log.NewNoop()) + + err := visibilityManager.DeleteWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPinotTripleDeleteUninitializedWorkflowExecution(t *testing.T) { + request := &VisibilityDeleteWorkflowExecutionRequest{} + + // put this outside because need to use it as an input of the table tests + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *VisibilityDeleteWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockPinotVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockPinotVisibilityManagerAccordance func(mockPinotVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().DeleteUninitializedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with Pinot visibility is not nil": { + request: request, + mockPinotVisibilityManager: NewMockVisibilityManager(ctrl), + mockPinotVisibilityManagerAccordance: func(mockPinotVisibilityManager *MockVisibilityManager) { + mockPinotVisibilityManager.EXPECT().DeleteUninitializedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().DeleteUninitializedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockPinotVisibilityManager != nil { + test.mockPinotVisibilityManagerAccordance(test.mockPinotVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + + visibilityManager := NewPinotVisibilityTripleManager(test.mockDBVisibilityManager, test.mockPinotVisibilityManager, + test.mockESVisibilityManager, nil, nil, + test.advancedVisibilityWritingMode, nil, log.NewNoop()) + + err := visibilityManager.DeleteUninitializedWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestFilterAttrPrefix(t *testing.T) { + tests := map[string]struct { + expectedInput string + expectedOutput string + }{ + "Case1: empty input": { + expectedInput: "", + expectedOutput: "", + }, + "Case2: filtered input": { + expectedInput: "`Attr.CustomIntField` = 12", + expectedOutput: "CustomIntField = 12", + }, + "Case3: complex input": { + expectedInput: "WorkflowID = 'test-wf' and (`Attr.CustomIntField` = 12 or `Attr.CustomStringField` = 'a-b-c' and WorkflowType = 'wf-type')", + expectedOutput: "WorkflowID = 'test-wf' and (CustomIntField = 12 or CustomStringField = 'a-b-c' and WorkflowType = 'wf-type')", + }, + "Case4: false positive case": { + expectedInput: "`Attr.CustomStringField` = '`Attr.ABCtesting'", + expectedOutput: "CustomStringField = 'ABCtesting'", // this is supposed to be CustomStringField = '`Attr.ABCtesting' + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert.NotPanics(t, func() { + actualOutput := filterAttrPrefix(test.expectedInput) + assert.Equal(t, test.expectedOutput, actualOutput) + }) + }) + } +} diff --git a/common/persistence/visibilityDualManager.go b/common/persistence/visibility_dual_manager.go similarity index 100% rename from common/persistence/visibilityDualManager.go rename to common/persistence/visibility_dual_manager.go diff --git a/common/persistence/visibility_dual_manager_test.go b/common/persistence/visibility_dual_manager_test.go new file mode 100644 index 00000000000..a9e0c4884d9 --- /dev/null +++ b/common/persistence/visibility_dual_manager_test.go @@ -0,0 +1,1228 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package persistence + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/dynamicconfig" + "github.com/uber/cadence/common/log" +) + +func TestNewVisibilityDualManager(t *testing.T) { + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + }{ + "Case1: nil case": { + mockDBVisibilityManager: nil, + mockESVisibilityManager: nil, + }, + "Case2: success case": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + assert.NotPanics(t, func() { + NewVisibilityDualManager(test.mockDBVisibilityManager, test.mockESVisibilityManager, nil, nil, log.NewNoop()) + }) + }) + } +} + +func TestDualManagerClose(t *testing.T) { + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + }{ + "Case1-1: success case with DB visibility is not nil": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().Close().Return().Times(1) + }, + }, + "Case1-2: success case with ES visibility is not nil": { + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().Close().Return().Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, test.mockESVisibilityManager, nil, nil, log.NewNoop()) + assert.NotPanics(t, func() { + visibilityManager.Close() + }) + }) + } +} + +func TestDualManagerGetName(t *testing.T) { + ctrl := gomock.NewController(t) + + tests := map[string]struct { + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + }{ + "Case1-1: success case with DB visibility is not nil": { + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().GetName().Return(testTableName).Times(1) + + }, + }, + "Case1-2: success case with ES visibility is not nil": { + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().GetName().Return(testTableName).Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, test.mockESVisibilityManager, nil, nil, log.NewNoop()) + + assert.NotPanics(t, func() { + visibilityManager.GetName() + }) + }) + } +} + +func TestDualRecordWorkflowExecutionStarted(t *testing.T) { + request := &RecordWorkflowExecutionStartedRequest{} + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *RecordWorkflowExecutionStartedRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionStarted(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionStarted(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualRecordWorkflowExecutionClosed(t *testing.T) { + request := &RecordWorkflowExecutionClosedRequest{} + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + context context.Context + request *RecordWorkflowExecutionClosedRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case0-1: error case with advancedVisibilityWritingMode is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + expectedError: fmt.Errorf("error"), + }, + "Case0-2: error case with ES has errors in dual mode": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(fmt.Errorf("error")).AnyTimes() + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + expectedError: fmt.Errorf("error"), + }, + "Case1-1: success case with DB visibility is not nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + "Case1-3: success case with dual manager": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + expectedError: nil, + }, + "Case2-1: choose DB visibility manager when it is nil": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + }, + "Case2-2: choose ES visibility manager when it is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + }, + "Case2-3: choose both when ES is nil": { + context: context.Background(), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + }, + "Case2-4: choose both when DB is nil": { + context: context.Background(), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeDual), + }, + "Case3-1: chooseVisibilityModeForAdmin when ES is nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "Case3-2: chooseVisibilityModeForAdmin when DB is nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + "Case3-3: chooseVisibilityModeForAdmin when both are not nil": { + context: context.WithValue(context.Background(), VisibilityAdminDeletionKey("visibilityAdminDelete"), true), + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionClosed(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionClosed(test.context, test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// test an edge case +func TestChooseVisibilityModeForAdmin(t *testing.T) { + ctrl := gomock.NewController(t) + dbManager := NewMockVisibilityManager(ctrl) + esManager := NewMockVisibilityManager(ctrl) + mgr := NewVisibilityDualManager(dbManager, esManager, nil, nil, log.NewNoop()) + dualManager := mgr.(*visibilityDualManager) + dualManager.dbVisibilityManager = nil + dualManager.esVisibilityManager = nil + assert.Equal(t, "INVALID_ADMIN_MODE", dualManager.chooseVisibilityModeForAdmin()) +} + +func TestDualRecordWorkflowExecutionUninitialized(t *testing.T) { + request := &RecordWorkflowExecutionUninitializedRequest{} + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *RecordWorkflowExecutionUninitializedRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().RecordWorkflowExecutionUninitialized(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().RecordWorkflowExecutionUninitialized(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.RecordWorkflowExecutionUninitialized(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualUpsertWorkflowExecution(t *testing.T) { + request := &UpsertWorkflowExecutionRequest{} + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *UpsertWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().UpsertWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.UpsertWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualDeleteWorkflowExecution(t *testing.T) { + request := &VisibilityDeleteWorkflowExecutionRequest{} + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *VisibilityDeleteWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().DeleteWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.DeleteWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualDeleteUninitializedWorkflowExecution(t *testing.T) { + request := &VisibilityDeleteWorkflowExecutionRequest{} + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *VisibilityDeleteWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + advancedVisibilityWritingMode dynamicconfig.StringPropertyFn + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().DeleteUninitializedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOff), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().DeleteUninitializedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil).Times(1) + }, + advancedVisibilityWritingMode: dynamicconfig.GetStringPropertyFn(common.AdvancedVisibilityWritingModeOn), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, nil, test.advancedVisibilityWritingMode, log.NewNoop()) + + err := visibilityManager.DeleteUninitializedWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListOpenWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsRequest{ + Domain: "test-domain", + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListOpenWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListOpenWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListOpenWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListClosedWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsRequest{ + Domain: "test-domain", + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case2-1: success case with DB visibility is not nil and read mod is false": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(false), + expectedError: nil, + }, + "Case2-2: success case with ES visibility is not nil and read mod is false": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListClosedWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(false), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListOpenWorkflowExecutionsByType(t *testing.T) { + request := &ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByTypeRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListOpenWorkflowExecutionsByType(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListClosedWorkflowExecutionsByType(t *testing.T) { + request := &ListWorkflowExecutionsByTypeRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByTypeRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByType(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutionsByType(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListOpenWorkflowExecutionsByWorkflowID(t *testing.T) { + request := &ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByWorkflowIDRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListOpenWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListOpenWorkflowExecutionsByWorkflowID(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListClosedWorkflowExecutionsByWorkflowID(t *testing.T) { + request := &ListWorkflowExecutionsByWorkflowIDRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByWorkflowIDRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByWorkflowID(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutionsByWorkflowID(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListClosedWorkflowExecutionsByStatus(t *testing.T) { + request := &ListClosedWorkflowExecutionsByStatusRequest{ + ListWorkflowExecutionsRequest: ListWorkflowExecutionsRequest{ + Domain: "test-domain", + }, + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListClosedWorkflowExecutionsByStatusRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByStatus(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListClosedWorkflowExecutionsByStatus(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListClosedWorkflowExecutionsByStatus(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualGetClosedWorkflowExecution(t *testing.T) { + request := &GetClosedWorkflowExecutionRequest{ + Domain: "test-domain", + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *GetClosedWorkflowExecutionRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().GetClosedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().GetClosedWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.GetClosedWorkflowExecution(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualListWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsByQueryRequest{ + Domain: "test-domain", + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByQueryRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ListWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ListWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ListWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualScanWorkflowExecutions(t *testing.T) { + request := &ListWorkflowExecutionsByQueryRequest{ + Domain: "test-domain", + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *ListWorkflowExecutionsByQueryRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().ScanWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().ScanWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.ScanWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDualCountWorkflowExecutions(t *testing.T) { + request := &CountWorkflowExecutionsRequest{ + Domain: "test-domain", + } + + ctrl := gomock.NewController(t) + + tests := map[string]struct { + request *CountWorkflowExecutionsRequest + mockDBVisibilityManager VisibilityManager + mockESVisibilityManager VisibilityManager + mockDBVisibilityManagerAccordance func(mockDBVisibilityManager *MockVisibilityManager) + mockESVisibilityManagerAccordance func(mockESVisibilityManager *MockVisibilityManager) + readModeIsFromES dynamicconfig.BoolPropertyFnWithDomainFilter + expectedError error + }{ + "Case1-1: success case with DB visibility is not nil": { + request: request, + mockDBVisibilityManager: NewMockVisibilityManager(ctrl), + mockDBVisibilityManagerAccordance: func(mockDBVisibilityManager *MockVisibilityManager) { + mockDBVisibilityManager.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + "Case1-2: success case with ES visibility is not nil": { + request: request, + mockESVisibilityManager: NewMockVisibilityManager(ctrl), + mockESVisibilityManagerAccordance: func(mockESVisibilityManager *MockVisibilityManager) { + mockESVisibilityManager.EXPECT().CountWorkflowExecutions(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1) + }, + readModeIsFromES: dynamicconfig.GetBoolPropertyFnFilteredByDomain(true), + expectedError: nil, + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + if test.mockDBVisibilityManager != nil { + test.mockDBVisibilityManagerAccordance(test.mockDBVisibilityManager.(*MockVisibilityManager)) + } + if test.mockESVisibilityManager != nil { + test.mockESVisibilityManagerAccordance(test.mockESVisibilityManager.(*MockVisibilityManager)) + } + visibilityManager := NewVisibilityDualManager(test.mockDBVisibilityManager, + test.mockESVisibilityManager, test.readModeIsFromES, nil, log.NewNoop()) + + _, err := visibilityManager.CountWorkflowExecutions(context.Background(), test.request) + if test.expectedError != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/common/taskTokenSerializerInterfaces.go b/common/taskTokenSerializerInterfaces.go index 107a35b1646..6afeb67ae69 100644 --- a/common/taskTokenSerializerInterfaces.go +++ b/common/taskTokenSerializerInterfaces.go @@ -48,3 +48,11 @@ type ( TaskID string `json:"taskId"` } ) + +func (t TaskToken) GetDomainID() string { + return t.DomainID +} + +func (t QueryTaskToken) GetDomainID() string { + return t.DomainID +} diff --git a/go.mod b/go.mod index bfc56339f44..bd135351b73 100644 --- a/go.mod +++ b/go.mod @@ -130,6 +130,7 @@ require ( github.com/xdg/stringprep v1.0.0 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect go.uber.org/dig v1.10.0 // indirect + go.uber.org/goleak v1.0.0 go.uber.org/net/metrics v1.3.0 // indirect golang.org/x/crypto v0.16.0 // indirect golang.org/x/exp/typeparams v0.0.0-20220218215828-6cf2b201936e // indirect diff --git a/host/workflowidratelimit_test.go b/host/workflowidratelimit_test.go index da68fda1c54..a47c383655d 100644 --- a/host/workflowidratelimit_test.go +++ b/host/workflowidratelimit_test.go @@ -30,19 +30,19 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/uber/cadence/common" "github.com/uber/cadence/common/clock" "github.com/uber/cadence/common/dynamicconfig" "github.com/uber/cadence/common/persistence" pt "github.com/uber/cadence/common/persistence/persistence-tests" - - "github.com/uber/cadence/common" "github.com/uber/cadence/common/types" ) func TestWorkflowIDRateLimitIntegrationSuite(t *testing.T) { flag.Parse() - clusterConfig, err := GetTestClusterConfig("integration_wfidratelimit_cluster.yaml") + clusterConfig, err := GetTestClusterConfig("testdata/integration_wfidratelimit_cluster.yaml") if err != nil { panic(err) } @@ -129,14 +129,12 @@ func (s *WorkflowIDRateLimitIntegrationSuite) TestWorkflowIDSpecificRateLimits() // The ratelimit is 5 per second, so we should be able to start 5 workflows without any error for i := 0; i < 5; i++ { - s.Require().NotNil(s.engine) _, err := s.engine.StartWorkflowExecution(ctx, request) assert.NoError(s.T(), err) } // Now we should get a rate limit error for i := 0; i < 5; i++ { - s.Require().NotNil(s.engine) _, err := s.engine.StartWorkflowExecution(ctx, request) var busyErr *types.ServiceBusyError assert.ErrorAs(s.T(), err, &busyErr) diff --git a/service/frontend/templates/clusterredirection.tmpl b/service/frontend/templates/clusterredirection.tmpl index 91313c69b37..5adb51dd025 100644 --- a/service/frontend/templates/clusterredirection.tmpl +++ b/service/frontend/templates/clusterredirection.tmpl @@ -1,12 +1,10 @@ import ( "context" - "time" "go.uber.org/yarpc" "github.com/uber/cadence/common" "github.com/uber/cadence/common/config" - "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/resource" "github.com/uber/cadence/common/types" @@ -65,20 +63,27 @@ func (handler *clusterRedirectionHandler) {{$method.Declaration}} { var apiName = "{{$method.Name}}" var cluster string + {{$policyMethod := "WithDomainNameRedirect"}} + {{$domain := printf "%s.GetDomain()" (index $method.Params 1).Name}} + {{- if has $method.Name $domainIDAPIs}} + token := domainIDGetter(noopdomainIDGetter{}) + {{- end}} scope, startTime := handler.beforeCall(metrics.DCRedirection{{$method.Name}}Scope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + {{- if has $method.Name $domainIDAPIs}} + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) + {{- else}} + handler.afterCall(recover(), scope, startTime, {{$domain}}, "", cluster, &err) + {{- end}} }() - {{$policyMethod := "WithDomainNameRedirect"}} - {{$domain := printf "%s.GetDomain()" (index $method.Params 1).Name}} {{if has $method.Name $domainIDAPIs}} {{$policyMethod = "WithDomainIDRedirect"}} - {{$domain = "token.DomainID"}} + {{$domain = "token.GetDomainID()"}} {{if has $method.Name $queryTaskTokenAPIs}} - token, err := handler.tokenSerializer.DeserializeQueryTaskToken({{(index $method.Params 1).Name}}.TaskToken) + token, err = handler.tokenSerializer.DeserializeQueryTaskToken({{(index $method.Params 1).Name}}.TaskToken) {{- else}} - token, err := handler.tokenSerializer.Deserialize({{(index $method.Params 1).Name}}.TaskToken) + token, err = handler.tokenSerializer.Deserialize({{(index $method.Params 1).Name}}.TaskToken) {{- end}} if err != nil { {{- if eq (len $method.Results) 1}} @@ -123,7 +128,7 @@ func (handler *clusterRedirectionHandler) QueryWorkflow( } scope, startTime := handler.beforeCall(metrics.DCRedirectionQueryWorkflowScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &retError) + handler.afterCall(recover(), scope, startTime, request.GetDomain(), "", cluster, &retError) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, request.GetDomain(), apiName, func(targetDC string) error { @@ -140,26 +145,3 @@ func (handler *clusterRedirectionHandler) QueryWorkflow( return resp, err } - -func (handler *clusterRedirectionHandler) beforeCall( - scope int, -) (metrics.Scope, time.Time) { - return handler.GetMetricsClient().Scope(scope), handler.GetTimeSource().Now() -} - -func (handler *clusterRedirectionHandler) afterCall( - recovered interface{}, - scope metrics.Scope, - startTime time.Time, - cluster string, - retError *error, -) { - log.CapturePanic(recovered, handler.GetLogger(), retError) - - scope = scope.Tagged(metrics.TargetClusterTag(cluster)) - scope.IncCounter(metrics.CadenceDcRedirectionClientRequests) - scope.RecordTimer(metrics.CadenceDcRedirectionClientLatency, handler.GetTimeSource().Now().Sub(startTime)) - if *retError != nil { - scope.IncCounter(metrics.CadenceDcRedirectionClientFailures) - } -} diff --git a/service/frontend/wrappers/clusterredirection/api_generated.go b/service/frontend/wrappers/clusterredirection/api_generated.go index 4c43e378d8a..68374a096d6 100644 --- a/service/frontend/wrappers/clusterredirection/api_generated.go +++ b/service/frontend/wrappers/clusterredirection/api_generated.go @@ -28,13 +28,11 @@ package clusterredirection import ( "context" - "time" "go.uber.org/yarpc" "github.com/uber/cadence/common" "github.com/uber/cadence/common/config" - "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/metrics" "github.com/uber/cadence/common/resource" "github.com/uber/cadence/common/types" @@ -85,7 +83,7 @@ func (handler *clusterRedirectionHandler) CountWorkflowExecutions(ctx context.Co scope, startTime := handler.beforeCall(metrics.DCRedirectionCountWorkflowExecutionsScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, cp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, cp1.GetDomain(), apiName, func(targetDC string) error { @@ -117,7 +115,7 @@ func (handler *clusterRedirectionHandler) DescribeTaskList(ctx context.Context, scope, startTime := handler.beforeCall(metrics.DCRedirectionDescribeTaskListScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, dp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, dp1.GetDomain(), apiName, func(targetDC string) error { @@ -141,7 +139,7 @@ func (handler *clusterRedirectionHandler) DescribeWorkflowExecution(ctx context. scope, startTime := handler.beforeCall(metrics.DCRedirectionDescribeWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, dp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, dp1.GetDomain(), apiName, func(targetDC string) error { @@ -173,7 +171,7 @@ func (handler *clusterRedirectionHandler) GetTaskListsByDomain(ctx context.Conte scope, startTime := handler.beforeCall(metrics.DCRedirectionGetTaskListsByDomainScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, gp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, gp1.GetDomain(), apiName, func(targetDC string) error { @@ -197,7 +195,7 @@ func (handler *clusterRedirectionHandler) GetWorkflowExecutionHistory(ctx contex scope, startTime := handler.beforeCall(metrics.DCRedirectionGetWorkflowExecutionHistoryScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, gp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, gp1.GetDomain(), apiName, func(targetDC string) error { @@ -225,7 +223,7 @@ func (handler *clusterRedirectionHandler) ListArchivedWorkflowExecutions(ctx con scope, startTime := handler.beforeCall(metrics.DCRedirectionListArchivedWorkflowExecutionsScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, lp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, lp1.GetDomain(), apiName, func(targetDC string) error { @@ -249,7 +247,7 @@ func (handler *clusterRedirectionHandler) ListClosedWorkflowExecutions(ctx conte scope, startTime := handler.beforeCall(metrics.DCRedirectionListClosedWorkflowExecutionsScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, lp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, lp1.GetDomain(), apiName, func(targetDC string) error { @@ -277,7 +275,7 @@ func (handler *clusterRedirectionHandler) ListOpenWorkflowExecutions(ctx context scope, startTime := handler.beforeCall(metrics.DCRedirectionListOpenWorkflowExecutionsScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, lp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, lp1.GetDomain(), apiName, func(targetDC string) error { @@ -301,7 +299,7 @@ func (handler *clusterRedirectionHandler) ListTaskListPartitions(ctx context.Con scope, startTime := handler.beforeCall(metrics.DCRedirectionListTaskListPartitionsScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, lp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, lp1.GetDomain(), apiName, func(targetDC string) error { @@ -325,7 +323,7 @@ func (handler *clusterRedirectionHandler) ListWorkflowExecutions(ctx context.Con scope, startTime := handler.beforeCall(metrics.DCRedirectionListWorkflowExecutionsScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, lp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, lp1.GetDomain(), apiName, func(targetDC string) error { @@ -349,7 +347,7 @@ func (handler *clusterRedirectionHandler) PollForActivityTask(ctx context.Contex scope, startTime := handler.beforeCall(metrics.DCRedirectionPollForActivityTaskScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, pp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, pp1.GetDomain(), apiName, func(targetDC string) error { @@ -373,7 +371,7 @@ func (handler *clusterRedirectionHandler) PollForDecisionTask(ctx context.Contex scope, startTime := handler.beforeCall(metrics.DCRedirectionPollForDecisionTaskScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, pp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, pp1.GetDomain(), apiName, func(targetDC string) error { @@ -395,17 +393,18 @@ func (handler *clusterRedirectionHandler) RecordActivityTaskHeartbeat(ctx contex var apiName = "RecordActivityTaskHeartbeat" var cluster string + token := domainIDGetter(noopdomainIDGetter{}) scope, startTime := handler.beforeCall(metrics.DCRedirectionRecordActivityTaskHeartbeatScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) }() - token, err := handler.tokenSerializer.Deserialize(rp1.TaskToken) + token, err = handler.tokenSerializer.Deserialize(rp1.TaskToken) if err != nil { return nil, err } - err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.DomainID, apiName, func(targetDC string) error { + err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.GetDomainID(), apiName, func(targetDC string) error { cluster = targetDC switch { case targetDC == handler.currentClusterName: @@ -426,7 +425,7 @@ func (handler *clusterRedirectionHandler) RecordActivityTaskHeartbeatByID(ctx co scope, startTime := handler.beforeCall(metrics.DCRedirectionRecordActivityTaskHeartbeatByIDScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -450,7 +449,7 @@ func (handler *clusterRedirectionHandler) RefreshWorkflowTasks(ctx context.Conte scope, startTime := handler.beforeCall(metrics.DCRedirectionRefreshWorkflowTasksScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -478,7 +477,7 @@ func (handler *clusterRedirectionHandler) RequestCancelWorkflowExecution(ctx con scope, startTime := handler.beforeCall(metrics.DCRedirectionRequestCancelWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -502,7 +501,7 @@ func (handler *clusterRedirectionHandler) ResetStickyTaskList(ctx context.Contex scope, startTime := handler.beforeCall(metrics.DCRedirectionResetStickyTaskListScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -526,7 +525,7 @@ func (handler *clusterRedirectionHandler) ResetWorkflowExecution(ctx context.Con scope, startTime := handler.beforeCall(metrics.DCRedirectionResetWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -548,17 +547,18 @@ func (handler *clusterRedirectionHandler) RespondActivityTaskCanceled(ctx contex var apiName = "RespondActivityTaskCanceled" var cluster string + token := domainIDGetter(noopdomainIDGetter{}) scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondActivityTaskCanceledScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) }() - token, err := handler.tokenSerializer.Deserialize(rp1.TaskToken) + token, err = handler.tokenSerializer.Deserialize(rp1.TaskToken) if err != nil { return err } - err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.DomainID, apiName, func(targetDC string) error { + err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.GetDomainID(), apiName, func(targetDC string) error { cluster = targetDC switch { case targetDC == handler.currentClusterName: @@ -579,7 +579,7 @@ func (handler *clusterRedirectionHandler) RespondActivityTaskCanceledByID(ctx co scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondActivityTaskCanceledByIDScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -601,17 +601,18 @@ func (handler *clusterRedirectionHandler) RespondActivityTaskCompleted(ctx conte var apiName = "RespondActivityTaskCompleted" var cluster string + token := domainIDGetter(noopdomainIDGetter{}) scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondActivityTaskCompletedScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) }() - token, err := handler.tokenSerializer.Deserialize(rp1.TaskToken) + token, err = handler.tokenSerializer.Deserialize(rp1.TaskToken) if err != nil { return err } - err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.DomainID, apiName, func(targetDC string) error { + err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.GetDomainID(), apiName, func(targetDC string) error { cluster = targetDC switch { case targetDC == handler.currentClusterName: @@ -632,7 +633,7 @@ func (handler *clusterRedirectionHandler) RespondActivityTaskCompletedByID(ctx c scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondActivityTaskCompletedByIDScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -654,17 +655,18 @@ func (handler *clusterRedirectionHandler) RespondActivityTaskFailed(ctx context. var apiName = "RespondActivityTaskFailed" var cluster string + token := domainIDGetter(noopdomainIDGetter{}) scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondActivityTaskFailedScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) }() - token, err := handler.tokenSerializer.Deserialize(rp1.TaskToken) + token, err = handler.tokenSerializer.Deserialize(rp1.TaskToken) if err != nil { return err } - err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.DomainID, apiName, func(targetDC string) error { + err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.GetDomainID(), apiName, func(targetDC string) error { cluster = targetDC switch { case targetDC == handler.currentClusterName: @@ -685,7 +687,7 @@ func (handler *clusterRedirectionHandler) RespondActivityTaskFailedByID(ctx cont scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondActivityTaskFailedByIDScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -707,17 +709,18 @@ func (handler *clusterRedirectionHandler) RespondDecisionTaskCompleted(ctx conte var apiName = "RespondDecisionTaskCompleted" var cluster string + token := domainIDGetter(noopdomainIDGetter{}) scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondDecisionTaskCompletedScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) }() - token, err := handler.tokenSerializer.Deserialize(rp1.TaskToken) + token, err = handler.tokenSerializer.Deserialize(rp1.TaskToken) if err != nil { return nil, err } - err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.DomainID, apiName, func(targetDC string) error { + err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.GetDomainID(), apiName, func(targetDC string) error { cluster = targetDC switch { case targetDC == handler.currentClusterName: @@ -736,17 +739,18 @@ func (handler *clusterRedirectionHandler) RespondDecisionTaskFailed(ctx context. var apiName = "RespondDecisionTaskFailed" var cluster string + token := domainIDGetter(noopdomainIDGetter{}) scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondDecisionTaskFailedScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) }() - token, err := handler.tokenSerializer.Deserialize(rp1.TaskToken) + token, err = handler.tokenSerializer.Deserialize(rp1.TaskToken) if err != nil { return err } - err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.DomainID, apiName, func(targetDC string) error { + err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.GetDomainID(), apiName, func(targetDC string) error { cluster = targetDC switch { case targetDC == handler.currentClusterName: @@ -765,17 +769,18 @@ func (handler *clusterRedirectionHandler) RespondQueryTaskCompleted(ctx context. var apiName = "RespondQueryTaskCompleted" var cluster string + token := domainIDGetter(noopdomainIDGetter{}) scope, startTime := handler.beforeCall(metrics.DCRedirectionRespondQueryTaskCompletedScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, "", token.GetDomainID(), cluster, &err) }() - token, err := handler.tokenSerializer.DeserializeQueryTaskToken(rp1.TaskToken) + token, err = handler.tokenSerializer.DeserializeQueryTaskToken(rp1.TaskToken) if err != nil { return err } - err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.DomainID, apiName, func(targetDC string) error { + err = handler.redirectionPolicy.WithDomainIDRedirect(ctx, token.GetDomainID(), apiName, func(targetDC string) error { cluster = targetDC switch { case targetDC == handler.currentClusterName: @@ -796,7 +801,7 @@ func (handler *clusterRedirectionHandler) RestartWorkflowExecution(ctx context.C scope, startTime := handler.beforeCall(metrics.DCRedirectionRestartWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, rp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, rp1.GetDomain(), apiName, func(targetDC string) error { @@ -820,7 +825,7 @@ func (handler *clusterRedirectionHandler) ScanWorkflowExecutions(ctx context.Con scope, startTime := handler.beforeCall(metrics.DCRedirectionScanWorkflowExecutionsScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, lp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, lp1.GetDomain(), apiName, func(targetDC string) error { @@ -844,7 +849,7 @@ func (handler *clusterRedirectionHandler) SignalWithStartWorkflowExecution(ctx c scope, startTime := handler.beforeCall(metrics.DCRedirectionSignalWithStartWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, sp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, sp1.GetDomain(), apiName, func(targetDC string) error { @@ -868,7 +873,7 @@ func (handler *clusterRedirectionHandler) SignalWithStartWorkflowExecutionAsync( scope, startTime := handler.beforeCall(metrics.DCRedirectionSignalWithStartWorkflowExecutionAsyncScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, sp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, sp1.GetDomain(), apiName, func(targetDC string) error { @@ -892,7 +897,7 @@ func (handler *clusterRedirectionHandler) SignalWorkflowExecution(ctx context.Co scope, startTime := handler.beforeCall(metrics.DCRedirectionSignalWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, sp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, sp1.GetDomain(), apiName, func(targetDC string) error { @@ -916,7 +921,7 @@ func (handler *clusterRedirectionHandler) StartWorkflowExecution(ctx context.Con scope, startTime := handler.beforeCall(metrics.DCRedirectionStartWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, sp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, sp1.GetDomain(), apiName, func(targetDC string) error { @@ -940,7 +945,7 @@ func (handler *clusterRedirectionHandler) StartWorkflowExecutionAsync(ctx contex scope, startTime := handler.beforeCall(metrics.DCRedirectionStartWorkflowExecutionAsyncScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, sp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, sp1.GetDomain(), apiName, func(targetDC string) error { @@ -964,7 +969,7 @@ func (handler *clusterRedirectionHandler) TerminateWorkflowExecution(ctx context scope, startTime := handler.beforeCall(metrics.DCRedirectionTerminateWorkflowExecutionScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &err) + handler.afterCall(recover(), scope, startTime, tp1.GetDomain(), "", cluster, &err) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, tp1.GetDomain(), apiName, func(targetDC string) error { @@ -1002,7 +1007,7 @@ func (handler *clusterRedirectionHandler) QueryWorkflow( } scope, startTime := handler.beforeCall(metrics.DCRedirectionQueryWorkflowScope) defer func() { - handler.afterCall(recover(), scope, startTime, cluster, &retError) + handler.afterCall(recover(), scope, startTime, request.GetDomain(), "", cluster, &retError) }() err = handler.redirectionPolicy.WithDomainNameRedirect(ctx, request.GetDomain(), apiName, func(targetDC string) error { @@ -1019,26 +1024,3 @@ func (handler *clusterRedirectionHandler) QueryWorkflow( return resp, err } - -func (handler *clusterRedirectionHandler) beforeCall( - scope int, -) (metrics.Scope, time.Time) { - return handler.GetMetricsClient().Scope(scope), handler.GetTimeSource().Now() -} - -func (handler *clusterRedirectionHandler) afterCall( - recovered interface{}, - scope metrics.Scope, - startTime time.Time, - cluster string, - retError *error, -) { - log.CapturePanic(recovered, handler.GetLogger(), retError) - - scope = scope.Tagged(metrics.TargetClusterTag(cluster)) - scope.IncCounter(metrics.CadenceDcRedirectionClientRequests) - scope.RecordTimer(metrics.CadenceDcRedirectionClientLatency, handler.GetTimeSource().Now().Sub(startTime)) - if *retError != nil { - scope.IncCounter(metrics.CadenceDcRedirectionClientFailures) - } -} diff --git a/service/frontend/wrappers/clusterredirection/api_test.go b/service/frontend/wrappers/clusterredirection/api_test.go index 56db199c10e..4273da5af54 100644 --- a/service/frontend/wrappers/clusterredirection/api_test.go +++ b/service/frontend/wrappers/clusterredirection/api_test.go @@ -78,12 +78,6 @@ func TestClusterRedirectionHandlerSuite(t *testing.T) { suite.Run(t, s) } -func (s *clusterRedirectionHandlerSuite) SetupSuite() { -} - -func (s *clusterRedirectionHandlerSuite) TearDownSuite() { -} - func (s *clusterRedirectionHandlerSuite) SetupTest() { s.Assertions = require.New(s.T()) diff --git a/service/frontend/wrappers/clusterredirection/callwrappers.go b/service/frontend/wrappers/clusterredirection/callwrappers.go new file mode 100644 index 00000000000..0aec0759515 --- /dev/null +++ b/service/frontend/wrappers/clusterredirection/callwrappers.go @@ -0,0 +1,77 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package clusterredirection + +import ( + "time" + + "github.com/uber/cadence/common/log" + "github.com/uber/cadence/common/log/tag" + "github.com/uber/cadence/common/metrics" +) + +type ( + domainIDGetter interface { + GetDomainID() string + } +) + +func (handler *clusterRedirectionHandler) beforeCall( + scope int, +) (metrics.Scope, time.Time) { + return handler.GetMetricsClient().Scope(scope), handler.GetTimeSource().Now() +} + +func (handler *clusterRedirectionHandler) afterCall( + recovered interface{}, + scope metrics.Scope, + startTime time.Time, + domainName string, + domainID string, + cluster string, + retError *error, +) { + var extraTags []tag.Tag + if domainName != "" { + extraTags = append(extraTags, tag.WorkflowDomainName(domainName)) + } + if domainID != "" { + extraTags = append(extraTags, tag.WorkflowDomainID(domainID)) + } + log.CapturePanic(recovered, handler.GetLogger().WithTags(extraTags...), retError) + + scope = scope.Tagged(metrics.TargetClusterTag(cluster)) + scope.IncCounter(metrics.CadenceDcRedirectionClientRequests) + scope.RecordTimer(metrics.CadenceDcRedirectionClientLatency, handler.GetTimeSource().Now().Sub(startTime)) + if *retError != nil { + scope.IncCounter(metrics.CadenceDcRedirectionClientFailures) + } +} + +// noopdomainIDGetter is a domainIDGetter that always returns empty string. +// it is used for extraction of domainID from domainIDGetter in case of token extraction failure. +type noopdomainIDGetter struct{} + +func (noopdomainIDGetter) GetDomainID() string { + return "" +} diff --git a/service/frontend/wrappers/clusterredirection/policy.go b/service/frontend/wrappers/clusterredirection/policy.go index 227d36bbc6e..dde78e4b9b4 100644 --- a/service/frontend/wrappers/clusterredirection/policy.go +++ b/service/frontend/wrappers/clusterredirection/policy.go @@ -194,6 +194,9 @@ func (policy *selectedOrAllAPIsForwardingRedirectionPolicy) WithDomainIDRedirect if err != nil { return err } + if domainEntry.IsDeprecatedOrDeleted() { + return fmt.Errorf("domain %v is deprecated or deleted", domainEntry.GetInfo().Name) + } return policy.withRedirect(ctx, domainEntry, apiName, call) } @@ -203,6 +206,9 @@ func (policy *selectedOrAllAPIsForwardingRedirectionPolicy) WithDomainNameRedire if err != nil { return err } + if domainEntry.IsDeprecatedOrDeleted() { + return fmt.Errorf("domain %v is deprecated or deleted", domainName) + } return policy.withRedirect(ctx, domainEntry, apiName, call) } diff --git a/service/frontend/wrappers/clusterredirection/policy_test.go b/service/frontend/wrappers/clusterredirection/policy_test.go index f7733ce2c93..aa3d1d280ef 100644 --- a/service/frontend/wrappers/clusterredirection/policy_test.go +++ b/service/frontend/wrappers/clusterredirection/policy_test.go @@ -69,13 +69,6 @@ func TestNoopDCRedirectionPolicySuite(t *testing.T) { suite.Run(t, s) } -func (s *noopDCRedirectionPolicySuite) SetupSuite() { -} - -func (s *noopDCRedirectionPolicySuite) TearDownSuite() { - -} - func (s *noopDCRedirectionPolicySuite) SetupTest() { s.Assertions = require.New(s.T()) @@ -343,6 +336,79 @@ func (s *selectedAPIsForwardingRedirectionPolicySuite) TestGetTargetDataCenter_G s.Equal(2*len(selectedAPIsForwardingRedirectionPolicyAPIAllowlist), alternativeClustercallCount) } +func (s *selectedAPIsForwardingRedirectionPolicySuite) TestGetTargetDataCenter_GlobalDomain_NoDomainInCache() { + currentClustercallCount := 0 + alternativeClustercallCount := 0 + callFn := func(targetCluster string) error { + switch targetCluster { + case s.currentClusterName: + currentClustercallCount++ + return nil + case s.alternativeClusterName: + alternativeClustercallCount++ + return &types.DomainNotActiveError{ + CurrentCluster: s.alternativeClusterName, + ActiveCluster: s.currentClusterName, + } + default: + panic(fmt.Sprintf("unknown cluster name %v", targetCluster)) + } + } + + expectedErr := fmt.Errorf("some random error") + s.mockDomainCache.EXPECT().GetDomainByID(s.domainID).Return(nil, expectedErr).Times(len(selectedAPIsForwardingRedirectionPolicyAPIAllowlist)) + s.mockDomainCache.EXPECT().GetDomain(s.domainName).Return(nil, expectedErr).Times(len(selectedAPIsForwardingRedirectionPolicyAPIAllowlist)) + + for apiName := range selectedAPIsForwardingRedirectionPolicyAPIAllowlist { + err := s.policy.WithDomainIDRedirect(context.Background(), s.domainID, apiName, callFn) + s.Error(err) + s.Equal(expectedErr.Error(), err.Error()) + + err = s.policy.WithDomainNameRedirect(context.Background(), s.domainName, apiName, callFn) + s.Error(err) + s.Equal(expectedErr.Error(), err.Error()) + } + + // Ensure there were no calls to the target clusters + s.Equal(0, currentClustercallCount) + s.Equal(0, alternativeClustercallCount) +} + +func (s *selectedAPIsForwardingRedirectionPolicySuite) TestGetTargetDataCenter_GlobalDomain_Forwarding_DeprecatedDomain() { + s.setupGlobalDeprecatedDomainWithTwoReplicationCluster(true, false) + + currentClustercallCount := 0 + alternativeClustercallCount := 0 + callFn := func(targetCluster string) error { + switch targetCluster { + case s.currentClusterName: + currentClustercallCount++ + return nil + case s.alternativeClusterName: + alternativeClustercallCount++ + return &types.DomainNotActiveError{ + CurrentCluster: s.alternativeClusterName, + ActiveCluster: s.currentClusterName, + } + default: + panic(fmt.Sprintf("unknown cluster name %v", targetCluster)) + } + } + + for apiName := range selectedAPIsForwardingRedirectionPolicyAPIAllowlist { + err := s.policy.WithDomainIDRedirect(context.Background(), s.domainID, apiName, callFn) + s.Error(err) + s.Equal(fmt.Sprintf("domain %v is deprecated or deleted", s.domainName), err.Error()) + + err = s.policy.WithDomainNameRedirect(context.Background(), s.domainName, apiName, callFn) + s.Error(err) + s.Equal(fmt.Sprintf("domain %v is deprecated or deleted", s.domainName), err.Error()) + } + + s.Equal(0, currentClustercallCount) + s.Equal(0, alternativeClustercallCount) +} + func (s *selectedAPIsForwardingRedirectionPolicySuite) setupLocalDomain() { domainEntry := cache.NewLocalDomainCacheEntryForTest( &persistence.DomainInfo{ID: s.domainID, Name: s.domainName}, @@ -376,3 +442,26 @@ func (s *selectedAPIsForwardingRedirectionPolicySuite) setupGlobalDomainWithTwoR s.mockDomainCache.EXPECT().GetDomain(s.domainName).Return(domainEntry, nil).AnyTimes() s.mockConfig.EnableDomainNotActiveAutoForwarding = dynamicconfig.GetBoolPropertyFnFilteredByDomain(forwardingEnabled) } + +func (s *selectedAPIsForwardingRedirectionPolicySuite) setupGlobalDeprecatedDomainWithTwoReplicationCluster(forwardingEnabled bool, isRecordActive bool) { + activeCluster := s.alternativeClusterName + if isRecordActive { + activeCluster = s.currentClusterName + } + domainEntry := cache.NewGlobalDomainCacheEntryForTest( + &persistence.DomainInfo{ID: s.domainID, Name: s.domainName, Status: persistence.DomainStatusDeprecated}, + &persistence.DomainConfig{Retention: 1}, + &persistence.DomainReplicationConfig{ + ActiveClusterName: activeCluster, + Clusters: []*persistence.ClusterReplicationConfig{ + {ClusterName: cluster.TestCurrentClusterName}, + {ClusterName: cluster.TestAlternativeClusterName}, + }, + }, + 1234, // not used + ) + + s.mockDomainCache.EXPECT().GetDomainByID(s.domainID).Return(domainEntry, nil).AnyTimes() + s.mockDomainCache.EXPECT().GetDomain(s.domainName).Return(domainEntry, nil).AnyTimes() + s.mockConfig.EnableDomainNotActiveAutoForwarding = dynamicconfig.GetBoolPropertyFnFilteredByDomain(forwardingEnabled) +} diff --git a/service/history/execution/context.go b/service/history/execution/context.go index f47fdab0e2d..84c8def7f9f 100644 --- a/service/history/execution/context.go +++ b/service/history/execution/context.go @@ -40,6 +40,7 @@ import ( "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/types" hcommon "github.com/uber/cadence/service/history/common" + "github.com/uber/cadence/service/history/engine" "github.com/uber/cadence/service/history/events" "github.com/uber/cadence/service/history/shard" ) @@ -170,10 +171,24 @@ type ( logger log.Logger metricsClient metrics.Client - mutex locks.Mutex - mutableState MutableState - stats *persistence.ExecutionStats - updateCondition int64 + mutex locks.Mutex + mutableState MutableState + stats *persistence.ExecutionStats + + appendHistoryNodesFn func(context.Context, string, types.WorkflowExecution, *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) + persistStartWorkflowBatchEventsFn func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) + persistNonStartWorkflowBatchEventsFn func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) + createWorkflowExecutionFn func(context.Context, *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) + updateWorkflowExecutionFn func(context.Context, *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) + notifyTasksFromWorkflowSnapshotFn func(*persistence.WorkflowSnapshot, events.PersistedBlobs, bool) + notifyTasksFromWorkflowMutationFn func(*persistence.WorkflowMutation, events.PersistedBlobs, bool) + emitSessionUpdateStatsFn func(string, *persistence.MutableStateUpdateSessionStats) + emitWorkflowHistoryStatsFn func(string, int, int) + emitWorkflowCompletionStatsFn func(string, string, string, string, string, *types.HistoryEvent) + mergeContinueAsNewReplicationTasksFn func(persistence.UpdateWorkflowMode, *persistence.WorkflowMutation, *persistence.WorkflowSnapshot) error + updateWorkflowExecutionEventReapplyFn func(persistence.UpdateWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error + conflictResolveEventReapplyFn func(persistence.ConflictResolveWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error + emitLargeWorkflowShardIDStatsFn func(int64, int64, int64, int64) } ) @@ -187,7 +202,8 @@ func NewContext( executionManager persistence.ExecutionManager, logger log.Logger, ) Context { - return &contextImpl{ + logger = logger.WithTags(tag.WorkflowDomainID(domainID), tag.WorkflowID(execution.GetWorkflowID()), tag.WorkflowRunID(execution.GetRunID())) + ctx := &contextImpl{ domainID: domainID, workflowExecution: execution, shard: shard, @@ -198,7 +214,39 @@ func NewContext( stats: &persistence.ExecutionStats{ HistorySize: 0, }, - } + + appendHistoryNodesFn: func(ctx context.Context, domainID string, workflowExecution types.WorkflowExecution, request *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) { + return appendHistoryV2EventsWithRetry(ctx, shard, common.CreatePersistenceRetryPolicy(), domainID, workflowExecution, request) + }, + createWorkflowExecutionFn: func(ctx context.Context, request *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) { + return createWorkflowExecutionWithRetry(ctx, shard, logger, common.CreatePersistenceRetryPolicy(), request) + }, + updateWorkflowExecutionFn: func(ctx context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + return updateWorkflowExecutionWithRetry(ctx, shard, logger, common.CreatePersistenceRetryPolicy(), request) + }, + notifyTasksFromWorkflowSnapshotFn: func(snapshot *persistence.WorkflowSnapshot, blobs events.PersistedBlobs, persistentError bool) { + notifyTasksFromWorkflowSnapshot(shard.GetEngine(), snapshot, blobs, persistentError) + }, + notifyTasksFromWorkflowMutationFn: func(snapshot *persistence.WorkflowMutation, blobs events.PersistedBlobs, persistentError bool) { + notifyTasksFromWorkflowMutation(shard.GetEngine(), snapshot, blobs, persistentError) + }, + emitSessionUpdateStatsFn: func(domainName string, stats *persistence.MutableStateUpdateSessionStats) { + emitSessionUpdateStats(shard.GetMetricsClient(), domainName, stats) + }, + emitWorkflowHistoryStatsFn: func(domainName string, historySize int, eventCount int) { + emitWorkflowHistoryStats(shard.GetMetricsClient(), domainName, historySize, eventCount) + }, + emitWorkflowCompletionStatsFn: func(domainName, workflowType, workflowID, runID, taskList string, event *types.HistoryEvent) { + emitWorkflowCompletionStats(shard.GetMetricsClient(), logger, domainName, workflowType, workflowID, runID, taskList, event) + }, + mergeContinueAsNewReplicationTasksFn: mergeContinueAsNewReplicationTasks, + } + ctx.persistStartWorkflowBatchEventsFn = ctx.PersistStartWorkflowBatchEvents + ctx.persistNonStartWorkflowBatchEventsFn = ctx.PersistNonStartWorkflowBatchEvents + ctx.updateWorkflowExecutionEventReapplyFn = ctx.updateWorkflowExecutionEventReapply + ctx.conflictResolveEventReapplyFn = ctx.conflictResolveEventReapply + ctx.emitLargeWorkflowShardIDStatsFn = ctx.emitLargeWorkflowShardIDStats + return ctx } func (c *contextImpl) Lock(ctx context.Context) error { @@ -305,7 +353,6 @@ func (c *contextImpl) LoadWorkflowExecutionWithTaskVersion( } c.stats = response.State.ExecutionStats - c.updateCondition = response.State.ExecutionInfo.NextEventID // finally emit execution and session stats emitWorkflowExecutionStats( @@ -394,22 +441,18 @@ func (c *contextImpl) CreateWorkflowExecution( HistorySize: historySize, } - resp, err := c.createWorkflowExecutionWithRetry(ctx, createRequest) + resp, err := c.createWorkflowExecutionFn(ctx, createRequest) if err != nil { - if c.isPersistenceTimeoutError(err) { - c.notifyTasksFromWorkflowSnapshot(newWorkflow, events.PersistedBlobs{persistedHistory}, true) + if isOperationPossiblySuccessfulError(err) { + c.notifyTasksFromWorkflowSnapshotFn(newWorkflow, events.PersistedBlobs{persistedHistory}, true) } return err } - c.notifyTasksFromWorkflowSnapshot(newWorkflow, events.PersistedBlobs{persistedHistory}, false) + c.notifyTasksFromWorkflowSnapshotFn(newWorkflow, events.PersistedBlobs{persistedHistory}, false) // finally emit session stats - emitSessionUpdateStats( - c.metricsClient, - domainName, - resp.MutableStateUpdateSessionStats, - ) + c.emitSessionUpdateStatsFn(domainName, resp.MutableStateUpdateSessionStats) return nil } @@ -425,17 +468,13 @@ func (c *contextImpl) ConflictResolveWorkflowExecution( currentMutableState MutableState, currentTransactionPolicy *TransactionPolicy, ) (retError error) { - defer func() { if retError != nil { c.Clear() } }() - resetWorkflow, resetWorkflowEventsSeq, err := resetMutableState.CloseTransactionAsSnapshot( - now, - TransactionPolicyPassive, - ) + resetWorkflow, resetWorkflowEventsSeq, err := resetMutableState.CloseTransactionAsSnapshot(now, TransactionPolicyPassive) if err != nil { return err } @@ -443,7 +482,7 @@ func (c *contextImpl) ConflictResolveWorkflowExecution( var persistedBlobs events.PersistedBlobs resetHistorySize := c.GetHistorySize() for _, workflowEvents := range resetWorkflowEventsSeq { - blob, err := c.PersistNonStartWorkflowBatchEvents(ctx, workflowEvents) + blob, err := c.persistNonStartWorkflowBatchEventsFn(ctx, workflowEvents) if err != nil { return err } @@ -458,23 +497,19 @@ func (c *contextImpl) ConflictResolveWorkflowExecution( var newWorkflow *persistence.WorkflowSnapshot var newWorkflowEventsSeq []*persistence.WorkflowEvents if newContext != nil && newMutableState != nil { - defer func() { if retError != nil { newContext.Clear() } }() - newWorkflow, newWorkflowEventsSeq, err = newMutableState.CloseTransactionAsSnapshot( - now, - TransactionPolicyPassive, - ) + newWorkflow, newWorkflowEventsSeq, err = newMutableState.CloseTransactionAsSnapshot(now, TransactionPolicyPassive) if err != nil { return err } newWorkflowSizeSize := newContext.GetHistorySize() startEvents := newWorkflowEventsSeq[0] - blob, err := c.PersistStartWorkflowBatchEvents(ctx, startEvents) + blob, err := c.persistStartWorkflowBatchEventsFn(ctx, startEvents) if err != nil { return err } @@ -489,23 +524,19 @@ func (c *contextImpl) ConflictResolveWorkflowExecution( var currentWorkflow *persistence.WorkflowMutation var currentWorkflowEventsSeq []*persistence.WorkflowEvents if currentContext != nil && currentMutableState != nil && currentTransactionPolicy != nil { - defer func() { if retError != nil { currentContext.Clear() } }() - currentWorkflow, currentWorkflowEventsSeq, err = currentMutableState.CloseTransactionAsMutation( - now, - *currentTransactionPolicy, - ) + currentWorkflow, currentWorkflowEventsSeq, err = currentMutableState.CloseTransactionAsMutation(now, *currentTransactionPolicy) if err != nil { return err } currentWorkflowSize := currentContext.GetHistorySize() for _, workflowEvents := range currentWorkflowEventsSeq { - blob, err := c.PersistNonStartWorkflowBatchEvents(ctx, workflowEvents) + blob, err := c.persistNonStartWorkflowBatchEventsFn(ctx, workflowEvents) if err != nil { return err } @@ -518,7 +549,7 @@ func (c *contextImpl) ConflictResolveWorkflowExecution( } } - if err := c.conflictResolveEventReapply( + if err := c.conflictResolveEventReapplyFn( conflictResolveMode, resetWorkflowEventsSeq, newWorkflowEventsSeq, @@ -540,10 +571,10 @@ func (c *contextImpl) ConflictResolveWorkflowExecution( DomainName: domain, }) if err != nil { - if c.isPersistenceTimeoutError(err) { - c.notifyTasksFromWorkflowSnapshot(resetWorkflow, persistedBlobs, true) - c.notifyTasksFromWorkflowSnapshot(newWorkflow, persistedBlobs, true) - c.notifyTasksFromWorkflowMutation(currentWorkflow, persistedBlobs, true) + if isOperationPossiblySuccessfulError(err) { + c.notifyTasksFromWorkflowSnapshotFn(resetWorkflow, persistedBlobs, true) + c.notifyTasksFromWorkflowSnapshotFn(newWorkflow, persistedBlobs, true) + c.notifyTasksFromWorkflowMutationFn(currentWorkflow, persistedBlobs, true) } return err } @@ -566,34 +597,21 @@ func (c *contextImpl) ConflictResolveWorkflowExecution( workflowCloseState, )) - c.notifyTasksFromWorkflowSnapshot(resetWorkflow, persistedBlobs, false) - c.notifyTasksFromWorkflowSnapshot(newWorkflow, persistedBlobs, false) - c.notifyTasksFromWorkflowMutation(currentWorkflow, persistedBlobs, false) + c.notifyTasksFromWorkflowSnapshotFn(resetWorkflow, persistedBlobs, false) + c.notifyTasksFromWorkflowSnapshotFn(newWorkflow, persistedBlobs, false) + c.notifyTasksFromWorkflowMutationFn(currentWorkflow, persistedBlobs, false) // finally emit session stats - domainName := c.GetDomainName() - emitWorkflowHistoryStats( - c.metricsClient, - domainName, - int(c.stats.HistorySize), - int(resetMutableState.GetNextEventID()-1), - ) - emitSessionUpdateStats( - c.metricsClient, - domainName, - resp.MutableStateUpdateSessionStats, - ) + c.emitWorkflowHistoryStatsFn(domain, int(c.stats.HistorySize), int(resetMutableState.GetNextEventID()-1)) + c.emitSessionUpdateStatsFn(domain, resp.MutableStateUpdateSessionStats) // emit workflow completion stats if any if resetWorkflow.ExecutionInfo.State == persistence.WorkflowStateCompleted { if event, err := resetMutableState.GetCompletionEvent(ctx); err == nil { workflowType := resetWorkflow.ExecutionInfo.WorkflowTypeName taskList := resetWorkflow.ExecutionInfo.TaskList - emitWorkflowCompletionStats(c.metricsClient, c.logger, - domainName, workflowType, c.workflowExecution.GetWorkflowID(), c.workflowExecution.GetRunID(), - taskList, event) + c.emitWorkflowCompletionStatsFn(domain, workflowType, c.workflowExecution.GetWorkflowID(), c.workflowExecution.GetRunID(), taskList, event) } } - return nil } @@ -676,16 +694,13 @@ func (c *contextImpl) UpdateWorkflowExecutionTasks( } }() - currentWorkflow, currentWorkflowEventsSeq, err := c.mutableState.CloseTransactionAsMutation( - now, - TransactionPolicyPassive, - ) + currentWorkflow, currentWorkflowEventsSeq, err := c.mutableState.CloseTransactionAsMutation(now, TransactionPolicyPassive) if err != nil { return err } if len(currentWorkflowEventsSeq) != 0 { - return types.InternalServiceError{ + return &types.InternalServiceError{ Message: "UpdateWorkflowExecutionTask can only be used for persisting new workflow tasks, but found new history events", } } @@ -696,7 +711,7 @@ func (c *contextImpl) UpdateWorkflowExecutionTasks( if errorDomainName != nil { return errorDomainName } - resp, err := c.updateWorkflowExecutionWithRetry(ctx, &persistence.UpdateWorkflowExecutionRequest{ + resp, err := c.updateWorkflowExecutionFn(ctx, &persistence.UpdateWorkflowExecutionRequest{ // RangeID , this is set by shard context Mode: persistence.UpdateWorkflowModeIgnoreCurrent, UpdateWorkflowMutation: *currentWorkflow, @@ -704,24 +719,14 @@ func (c *contextImpl) UpdateWorkflowExecutionTasks( DomainName: domainName, }) if err != nil { - if c.isPersistenceTimeoutError(err) { - c.notifyTasksFromWorkflowMutation(currentWorkflow, nil, true) + if isOperationPossiblySuccessfulError(err) { + c.notifyTasksFromWorkflowMutationFn(currentWorkflow, nil, true) } return err } - - // TODO remove updateCondition in favor of condition in mutable state - c.updateCondition = currentWorkflow.ExecutionInfo.NextEventID - // notify current workflow tasks - c.notifyTasksFromWorkflowMutation(currentWorkflow, nil, false) - - emitSessionUpdateStats( - c.metricsClient, - c.GetDomainName(), - resp.MutableStateUpdateSessionStats, - ) - + c.notifyTasksFromWorkflowMutationFn(currentWorkflow, nil, false) + c.emitSessionUpdateStatsFn(domainName, resp.MutableStateUpdateSessionStats) return nil } @@ -740,10 +745,7 @@ func (c *contextImpl) UpdateWorkflowExecutionWithNew( } }() - currentWorkflow, currentWorkflowEventsSeq, err := c.mutableState.CloseTransactionAsMutation( - now, - currentWorkflowTransactionPolicy, - ) + currentWorkflow, currentWorkflowEventsSeq, err := c.mutableState.CloseTransactionAsMutation(now, currentWorkflowTransactionPolicy) if err != nil { return err } @@ -753,11 +755,11 @@ func (c *contextImpl) UpdateWorkflowExecutionWithNew( currentWorkflowHistoryCount := c.mutableState.GetNextEventID() - 1 oldWorkflowHistoryCount := currentWorkflowHistoryCount for _, workflowEvents := range currentWorkflowEventsSeq { - blob, err := c.PersistNonStartWorkflowBatchEvents(ctx, workflowEvents) - currentWorkflowHistoryCount += int64(len(workflowEvents.Events)) + blob, err := c.persistNonStartWorkflowBatchEventsFn(ctx, workflowEvents) if err != nil { return err } + currentWorkflowHistoryCount += int64(len(workflowEvents.Events)) currentWorkflowSize += int64(len(blob.Data)) persistedBlobs = append(persistedBlobs, blob) } @@ -769,7 +771,6 @@ func (c *contextImpl) UpdateWorkflowExecutionWithNew( var newWorkflow *persistence.WorkflowSnapshot var newWorkflowEventsSeq []*persistence.WorkflowEvents if newContext != nil && newMutableState != nil && newWorkflowTransactionPolicy != nil { - defer func() { if retError != nil { newContext.Clear() @@ -788,13 +789,13 @@ func (c *contextImpl) UpdateWorkflowExecutionWithNew( firstEventID := startEvents.Events[0].ID var blob events.PersistedBlob if firstEventID == common.FirstEventID { - blob, err = c.PersistStartWorkflowBatchEvents(ctx, startEvents) + blob, err = c.persistStartWorkflowBatchEventsFn(ctx, startEvents) if err != nil { return err } } else { // NOTE: This is the case for reset workflow, reset workflow already inserted a branch record - blob, err = c.PersistNonStartWorkflowBatchEvents(ctx, startEvents) + blob, err = c.persistNonStartWorkflowBatchEventsFn(ctx, startEvents) if err != nil { return err } @@ -808,26 +809,18 @@ func (c *contextImpl) UpdateWorkflowExecutionWithNew( } } - if err := c.mergeContinueAsNewReplicationTasks( - updateMode, - currentWorkflow, - newWorkflow, - ); err != nil { + if err := c.mergeContinueAsNewReplicationTasksFn(updateMode, currentWorkflow, newWorkflow); err != nil { return err } - if err := c.updateWorkflowExecutionEventReapply( - updateMode, - currentWorkflowEventsSeq, - newWorkflowEventsSeq, - ); err != nil { + if err := c.updateWorkflowExecutionEventReapplyFn(updateMode, currentWorkflowEventsSeq, newWorkflowEventsSeq); err != nil { return err } domain, errorDomainName := c.shard.GetDomainCache().GetDomainName(c.domainID) if errorDomainName != nil { return errorDomainName } - resp, err := c.updateWorkflowExecutionWithRetry(ctx, &persistence.UpdateWorkflowExecutionRequest{ + resp, err := c.updateWorkflowExecutionFn(ctx, &persistence.UpdateWorkflowExecutionRequest{ // RangeID , this is set by shard context Mode: updateMode, UpdateWorkflowMutation: *currentWorkflow, @@ -836,16 +829,13 @@ func (c *contextImpl) UpdateWorkflowExecutionWithNew( DomainName: domain, }) if err != nil { - if c.isPersistenceTimeoutError(err) { - c.notifyTasksFromWorkflowMutation(currentWorkflow, persistedBlobs, true) - c.notifyTasksFromWorkflowSnapshot(newWorkflow, persistedBlobs, true) + if isOperationPossiblySuccessfulError(err) { + c.notifyTasksFromWorkflowMutationFn(currentWorkflow, persistedBlobs, true) + c.notifyTasksFromWorkflowSnapshotFn(newWorkflow, persistedBlobs, true) } return err } - // TODO remove updateCondition in favor of condition in mutable state - c.updateCondition = currentWorkflow.ExecutionInfo.NextEventID - // for any change in the workflow, send a event currentBranchToken, err := c.mutableState.GetCurrentBranchToken() if err != nil { @@ -864,40 +854,28 @@ func (c *contextImpl) UpdateWorkflowExecutionWithNew( )) // notify current workflow tasks - c.notifyTasksFromWorkflowMutation(currentWorkflow, persistedBlobs, false) - + c.notifyTasksFromWorkflowMutationFn(currentWorkflow, persistedBlobs, false) // notify new workflow tasks - c.notifyTasksFromWorkflowSnapshot(newWorkflow, persistedBlobs, false) + c.notifyTasksFromWorkflowSnapshotFn(newWorkflow, persistedBlobs, false) // finally emit session stats - domainName := c.GetDomainName() - emitWorkflowHistoryStats( - c.metricsClient, - domainName, - int(c.stats.HistorySize), - int(c.mutableState.GetNextEventID()-1), - ) - emitSessionUpdateStats( - c.metricsClient, - domainName, - resp.MutableStateUpdateSessionStats, - ) - c.emitLargeWorkflowShardIDStats(currentWorkflowSize-oldWorkflowSize, oldWorkflowHistoryCount, oldWorkflowSize, currentWorkflowHistoryCount) + c.emitWorkflowHistoryStatsFn(domain, int(c.stats.HistorySize), int(c.mutableState.GetNextEventID()-1)) + c.emitSessionUpdateStatsFn(domain, resp.MutableStateUpdateSessionStats) + c.emitLargeWorkflowShardIDStatsFn(currentWorkflowSize-oldWorkflowSize, oldWorkflowHistoryCount, oldWorkflowSize, currentWorkflowHistoryCount) // emit workflow completion stats if any if currentWorkflow.ExecutionInfo.State == persistence.WorkflowStateCompleted { if event, err := c.mutableState.GetCompletionEvent(ctx); err == nil { workflowType := currentWorkflow.ExecutionInfo.WorkflowTypeName taskList := currentWorkflow.ExecutionInfo.TaskList - emitWorkflowCompletionStats(c.metricsClient, c.logger, - domainName, workflowType, c.workflowExecution.GetWorkflowID(), c.workflowExecution.GetRunID(), - taskList, event) + c.emitWorkflowCompletionStatsFn(domain, workflowType, c.workflowExecution.GetWorkflowID(), c.workflowExecution.GetRunID(), taskList, event) } } return nil } -func (c *contextImpl) notifyTasksFromWorkflowSnapshot( +func notifyTasksFromWorkflowSnapshot( + engine engine.Engine, workflowSnapShot *persistence.WorkflowSnapshot, history events.PersistedBlobs, persistenceError bool, @@ -906,10 +884,11 @@ func (c *contextImpl) notifyTasksFromWorkflowSnapshot( return } - c.notifyTasks( + notifyTasks( + engine, workflowSnapShot.ExecutionInfo, workflowSnapShot.VersionHistories, - activityInfosToMap(workflowSnapShot.ActivityInfos), + workflowSnapShot.ActivityInfos, workflowSnapShot.TransferTasks, workflowSnapShot.TimerTasks, workflowSnapShot.CrossClusterTasks, @@ -919,7 +898,8 @@ func (c *contextImpl) notifyTasksFromWorkflowSnapshot( ) } -func (c *contextImpl) notifyTasksFromWorkflowMutation( +func notifyTasksFromWorkflowMutation( + engine engine.Engine, workflowMutation *persistence.WorkflowMutation, history events.PersistedBlobs, persistenceError bool, @@ -928,10 +908,11 @@ func (c *contextImpl) notifyTasksFromWorkflowMutation( return } - c.notifyTasks( + notifyTasks( + engine, workflowMutation.ExecutionInfo, workflowMutation.VersionHistories, - activityInfosToMap(workflowMutation.UpsertActivityInfos), + workflowMutation.UpsertActivityInfos, workflowMutation.TransferTasks, workflowMutation.TimerTasks, workflowMutation.CrossClusterTasks, @@ -949,10 +930,11 @@ func activityInfosToMap(ais []*persistence.ActivityInfo) map[int64]*persistence. return m } -func (c *contextImpl) notifyTasks( +func notifyTasks( + engine engine.Engine, executionInfo *persistence.WorkflowExecutionInfo, versionHistories *persistence.VersionHistories, - activities map[int64]*persistence.ActivityInfo, + activities []*persistence.ActivityInfo, transferTasks []persistence.Task, timerTasks []persistence.Task, crossClusterTasks []persistence.Task, @@ -979,18 +961,18 @@ func (c *contextImpl) notifyTasks( ExecutionInfo: executionInfo, Tasks: replicationTasks, VersionHistories: versionHistories, - Activities: activities, + Activities: activityInfosToMap(activities), History: history, PersistenceError: persistenceError, } - c.shard.GetEngine().NotifyNewTransferTasks(transferTaskInfo) - c.shard.GetEngine().NotifyNewTimerTasks(timerTaskInfo) - c.shard.GetEngine().NotifyNewCrossClusterTasks(crossClusterTaskInfo) - c.shard.GetEngine().NotifyNewReplicationTasks(replicationTaskInfo) + engine.NotifyNewTransferTasks(transferTaskInfo) + engine.NotifyNewTimerTasks(timerTaskInfo) + engine.NotifyNewCrossClusterTasks(crossClusterTaskInfo) + engine.NotifyNewReplicationTasks(replicationTaskInfo) } -func (c *contextImpl) mergeContinueAsNewReplicationTasks( +func mergeContinueAsNewReplicationTasks( updateMode persistence.UpdateWorkflowMode, currentWorkflowMutation *persistence.WorkflowMutation, newWorkflowSnapshot *persistence.WorkflowSnapshot, @@ -1061,7 +1043,7 @@ func (c *contextImpl) PersistStartWorkflowBatchEvents( RunID: workflowEvents.RunID, } - resp, err := c.appendHistoryV2EventsWithRetry( + resp, err := c.appendHistoryNodesFn( ctx, domainID, execution, @@ -1103,7 +1085,7 @@ func (c *contextImpl) PersistNonStartWorkflowBatchEvents( RunID: workflowEvents.RunID, } - resp, err := c.appendHistoryV2EventsWithRetry( + resp, err := c.appendHistoryNodesFn( ctx, domainID, execution, @@ -1125,8 +1107,10 @@ func (c *contextImpl) PersistNonStartWorkflowBatchEvents( }, nil } -func (c *contextImpl) appendHistoryV2EventsWithRetry( +func appendHistoryV2EventsWithRetry( ctx context.Context, + shardContext shard.Context, + retryPolicy backoff.RetryPolicy, domainID string, execution types.WorkflowExecution, request *persistence.AppendHistoryNodesRequest, @@ -1135,30 +1119,31 @@ func (c *contextImpl) appendHistoryV2EventsWithRetry( var resp *persistence.AppendHistoryNodesResponse op := func() error { var err error - resp, err = c.shard.AppendHistoryV2Events(ctx, request, domainID, execution) + resp, err = shardContext.AppendHistoryV2Events(ctx, request, domainID, execution) return err } - throttleRetry := backoff.NewThrottleRetry( - backoff.WithRetryPolicy(common.CreatePersistenceRetryPolicy()), + backoff.WithRetryPolicy(retryPolicy), backoff.WithRetryableError(persistence.IsTransientError), ) err := throttleRetry.Do(ctx, op) return resp, err } -func (c *contextImpl) createWorkflowExecutionWithRetry( +func createWorkflowExecutionWithRetry( ctx context.Context, + shardContext shard.Context, + logger log.Logger, + retryPolicy backoff.RetryPolicy, request *persistence.CreateWorkflowExecutionRequest, ) (*persistence.CreateWorkflowExecutionResponse, error) { var resp *persistence.CreateWorkflowExecutionResponse op := func() error { var err error - resp, err = c.shard.CreateWorkflowExecution(ctx, request) + resp, err = shardContext.CreateWorkflowExecution(ctx, request) return err } - isRetryable := func(err error) bool { if _, ok := err.(*persistence.TimeoutError); ok { // TODO: is timeout error retryable for create workflow? @@ -1168,11 +1153,11 @@ func (c *contextImpl) createWorkflowExecutionWithRetry( } return persistence.IsTransientError(err) } - throttleRetry := backoff.NewThrottleRetry( - backoff.WithRetryPolicy(common.CreatePersistenceRetryPolicy()), + backoff.WithRetryPolicy(retryPolicy), backoff.WithRetryableError(isRetryable), ) + err := throttleRetry.Do(ctx, op) switch err.(type) { case nil: @@ -1182,11 +1167,8 @@ func (c *contextImpl) createWorkflowExecutionWithRetry( // workflow ID reuse policy return nil, err default: - c.logger.Error( + logger.Error( "Persistent store operation failure", - tag.WorkflowID(c.workflowExecution.GetWorkflowID()), - tag.WorkflowRunID(c.workflowExecution.GetRunID()), - tag.WorkflowDomainID(c.domainID), tag.StoreOperationCreateWorkflowExecution, tag.Error(err), ) @@ -1203,7 +1185,6 @@ func (c *contextImpl) getWorkflowExecutionWithRetry( op := func() error { var err error resp, err = c.shard.GetWorkflowExecution(ctx, request) - return err } @@ -1221,9 +1202,6 @@ func (c *contextImpl) getWorkflowExecutionWithRetry( default: c.logger.Error( "Persistent fetch operation failure", - tag.WorkflowID(c.workflowExecution.GetWorkflowID()), - tag.WorkflowRunID(c.workflowExecution.GetRunID()), - tag.WorkflowDomainID(c.domainID), tag.StoreOperationGetWorkflowExecution, tag.Error(err), ) @@ -1231,15 +1209,18 @@ func (c *contextImpl) getWorkflowExecutionWithRetry( } } -func (c *contextImpl) updateWorkflowExecutionWithRetry( +func updateWorkflowExecutionWithRetry( ctx context.Context, + shardContext shard.Context, + logger log.Logger, + retryPolicy backoff.RetryPolicy, request *persistence.UpdateWorkflowExecutionRequest, ) (*persistence.UpdateWorkflowExecutionResponse, error) { var resp *persistence.UpdateWorkflowExecutionResponse op := func() error { var err error - resp, err = c.shard.UpdateWorkflowExecution(ctx, request) + resp, err = shardContext.UpdateWorkflowExecution(ctx, request) return err } // Preparation for the task Validation. @@ -1259,7 +1240,7 @@ func (c *contextImpl) updateWorkflowExecutionWithRetry( } throttleRetry := backoff.NewThrottleRetry( - backoff.WithRetryPolicy(common.CreatePersistenceRetryPolicy()), + backoff.WithRetryPolicy(retryPolicy), backoff.WithRetryableError(isRetryable), ) err := throttleRetry.Do(ctx, op) @@ -1269,14 +1250,11 @@ func (c *contextImpl) updateWorkflowExecutionWithRetry( case *persistence.ConditionFailedError: return nil, &conflictError{err} default: - c.logger.Error( + logger.Error( "Persistent store operation failure", - tag.WorkflowID(c.workflowExecution.GetWorkflowID()), - tag.WorkflowRunID(c.workflowExecution.GetRunID()), - tag.WorkflowDomainID(c.domainID), tag.StoreOperationUpdateWorkflowExecution, tag.Error(err), - tag.Number(c.updateCondition), + tag.Number(request.UpdateWorkflowMutation.Condition), ) // TODO: Call the Task Validation here so that it happens whenever an error happen during Update. // err1 := checker.WorkflowCheckforValidation( @@ -1340,8 +1318,6 @@ func (c *contextImpl) ReapplyEvents( workflowID := eventBatches[0].WorkflowID runID := eventBatches[0].RunID domainCache := c.shard.GetDomainCache() - clientBean := c.shard.GetService().GetClientBean() - serializer := c.shard.GetService().GetPayloadSerializer() domainEntry, err := domainCache.GetDomainByID(domainID) if err != nil { return err @@ -1369,12 +1345,6 @@ func (c *contextImpl) ReapplyEvents( return nil } - // Reapply events only reapply to the current run. - // The run id is only used for reapply event de-duplication - execution := &types.WorkflowExecution{ - WorkflowID: workflowID, - RunID: runID, - } ctx, cancel := context.WithTimeout(context.Background(), defaultRemoteCallTimeout) defer cancel() @@ -1389,6 +1359,14 @@ func (c *contextImpl) ReapplyEvents( ) } + // Reapply events only reapply to the current run. + // The run id is only used for reapply event de-duplication + execution := &types.WorkflowExecution{ + WorkflowID: workflowID, + RunID: runID, + } + clientBean := c.shard.GetService().GetClientBean() + serializer := c.shard.GetService().GetPayloadSerializer() // The active cluster of the domain is the same as current cluster. // Use the history from the same cluster to reapply events reapplyEventsDataBlob, err := serializer.SerializeBatchEvents( @@ -1417,12 +1395,7 @@ func (c *contextImpl) ReapplyEvents( ) } -func (c *contextImpl) isPersistenceTimeoutError( - err error, -) bool { - // TODO: ideally we only need to check if err has type *persistence.Timeout, - // but currently only cassandra will return timeout error of that type. - // so currently this method will return false positives +func isOperationPossiblySuccessfulError(err error) bool { switch err.(type) { case nil: return false diff --git a/service/history/execution/context_test.go b/service/history/execution/context_test.go new file mode 100644 index 00000000000..16556807e0d --- /dev/null +++ b/service/history/execution/context_test.go @@ -0,0 +1,2674 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package execution + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/uber/cadence/common" + "github.com/uber/cadence/common/backoff" + "github.com/uber/cadence/common/cache" + "github.com/uber/cadence/common/cluster" + "github.com/uber/cadence/common/log/testlogger" + "github.com/uber/cadence/common/metrics" + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/types" + hcommon "github.com/uber/cadence/service/history/common" + "github.com/uber/cadence/service/history/engine" + "github.com/uber/cadence/service/history/events" + "github.com/uber/cadence/service/history/resource" + "github.com/uber/cadence/service/history/shard" +) + +func TestIsOperationPossiblySuccessfulError(t *testing.T) { + assert.False(t, isOperationPossiblySuccessfulError(nil)) + assert.False(t, isOperationPossiblySuccessfulError(&types.WorkflowExecutionAlreadyStartedError{})) + assert.False(t, isOperationPossiblySuccessfulError(&persistence.WorkflowExecutionAlreadyStartedError{})) + assert.False(t, isOperationPossiblySuccessfulError(&persistence.CurrentWorkflowConditionFailedError{})) + assert.False(t, isOperationPossiblySuccessfulError(&persistence.ConditionFailedError{})) + assert.False(t, isOperationPossiblySuccessfulError(&types.ServiceBusyError{})) + assert.False(t, isOperationPossiblySuccessfulError(&types.LimitExceededError{})) + assert.False(t, isOperationPossiblySuccessfulError(&persistence.ShardOwnershipLostError{})) + assert.True(t, isOperationPossiblySuccessfulError(&persistence.TimeoutError{})) + assert.False(t, isOperationPossiblySuccessfulError(NewConflictError(t, &persistence.ConditionFailedError{}))) + assert.True(t, isOperationPossiblySuccessfulError(context.DeadlineExceeded)) +} + +func TestMergeContinueAsNewReplicationTasks(t *testing.T) { + testCases := []struct { + name string + updateMode persistence.UpdateWorkflowMode + currentWorkflowMutation *persistence.WorkflowMutation + newWorkflowSnapshot *persistence.WorkflowSnapshot + wantErr bool + assertErr func(*testing.T, error) + }{ + { + name: "current workflow does not continue as new", + currentWorkflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + CloseStatus: persistence.WorkflowCloseStatusCompleted, + }, + }, + wantErr: false, + }, + { + name: "update workflow as zombie and continue as new without new zombie workflow", + currentWorkflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + CloseStatus: persistence.WorkflowCloseStatusContinuedAsNew, + }, + }, + updateMode: persistence.UpdateWorkflowModeBypassCurrent, + wantErr: false, + }, + { + name: "continue as new on the passive side", + currentWorkflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + CloseStatus: persistence.WorkflowCloseStatusContinuedAsNew, + }, + }, + updateMode: persistence.UpdateWorkflowModeUpdateCurrent, + wantErr: false, + }, + { + name: "continue as new on the active side, but new workflow is not provided", + currentWorkflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + CloseStatus: persistence.WorkflowCloseStatusContinuedAsNew, + }, + ReplicationTasks: []persistence.Task{ + &persistence.HistoryReplicationTask{}, + }, + }, + updateMode: persistence.UpdateWorkflowModeUpdateCurrent, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, &types.InternalServiceError{}, err) + assert.Contains(t, err.Error(), "unable to find replication task from new workflow for continue as new replication") + }, + }, + { + name: "continue as new on the active side, but new workflow has no replication task", + currentWorkflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + CloseStatus: persistence.WorkflowCloseStatusContinuedAsNew, + }, + ReplicationTasks: []persistence.Task{ + &persistence.HistoryReplicationTask{}, + }, + }, + newWorkflowSnapshot: &persistence.WorkflowSnapshot{}, + updateMode: persistence.UpdateWorkflowModeUpdateCurrent, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, &types.InternalServiceError{}, err) + assert.Contains(t, err.Error(), "unable to find replication task from new workflow for continue as new replication") + }, + }, + { + name: "continue as new on the active side, but current workflow has no history replication task", + currentWorkflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + CloseStatus: persistence.WorkflowCloseStatusContinuedAsNew, + }, + ReplicationTasks: []persistence.Task{ + &persistence.SyncActivityTask{}, + }, + }, + newWorkflowSnapshot: &persistence.WorkflowSnapshot{ + ReplicationTasks: []persistence.Task{ + &persistence.HistoryReplicationTask{}, + }, + }, + updateMode: persistence.UpdateWorkflowModeUpdateCurrent, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, &types.InternalServiceError{}, err) + assert.Contains(t, err.Error(), "unable to find replication task from current workflow for continue as new replication") + }, + }, + { + name: "continue as new on the active side", + currentWorkflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + CloseStatus: persistence.WorkflowCloseStatusContinuedAsNew, + }, + ReplicationTasks: []persistence.Task{ + &persistence.HistoryReplicationTask{}, + }, + }, + newWorkflowSnapshot: &persistence.WorkflowSnapshot{ + ReplicationTasks: []persistence.Task{ + &persistence.HistoryReplicationTask{}, + }, + }, + updateMode: persistence.UpdateWorkflowModeUpdateCurrent, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := mergeContinueAsNewReplicationTasks(tc.updateMode, tc.currentWorkflowMutation, tc.newWorkflowSnapshot) + if tc.wantErr { + assert.Error(t, err) + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestNotifyTasksFromWorkflowSnapshot(t *testing.T) { + testCases := []struct { + name string + workflowSnapShot *persistence.WorkflowSnapshot + history events.PersistedBlobs + persistenceError bool + mockSetup func(*engine.MockEngine) + }{ + { + name: "Success case", + workflowSnapShot: &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + VersionHistories: &persistence.VersionHistories{ + CurrentVersionHistoryIndex: 0, + Histories: []*persistence.VersionHistory{ + { + BranchToken: []byte{1, 2, 3}, + }, + }, + }, + ActivityInfos: []*persistence.ActivityInfo{ + { + Version: 1, + ScheduleID: 11, + }, + }, + TransferTasks: []persistence.Task{ + &persistence.ActivityTask{ + TaskList: "test-tl", + }, + }, + TimerTasks: []persistence.Task{ + &persistence.ActivityTimeoutTask{ + Attempt: 10, + }, + }, + CrossClusterTasks: []persistence.Task{ + &persistence.CrossClusterStartChildExecutionTask{ + StartChildExecutionTask: persistence.StartChildExecutionTask{ + TargetDomainID: "target-domain", + }, + TargetCluster: "target", + }, + }, + ReplicationTasks: []persistence.Task{ + &persistence.HistoryReplicationTask{ + FirstEventID: 1, + NextEventID: 10, + }, + }, + }, + history: events.PersistedBlobs{ + events.PersistedBlob{}, + }, + persistenceError: true, + mockSetup: func(mockEngine *engine.MockEngine) { + mockEngine.EXPECT().NotifyNewTransferTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.ActivityTask{ + TaskList: "test-tl", + }, + }, + PersistenceError: true, + }) + mockEngine.EXPECT().NotifyNewTimerTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.ActivityTimeoutTask{ + Attempt: 10, + }, + }, + PersistenceError: true, + }) + mockEngine.EXPECT().NotifyNewCrossClusterTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.CrossClusterStartChildExecutionTask{ + StartChildExecutionTask: persistence.StartChildExecutionTask{ + TargetDomainID: "target-domain", + }, + TargetCluster: "target", + }, + }, + PersistenceError: true, + }) + mockEngine.EXPECT().NotifyNewReplicationTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.HistoryReplicationTask{ + FirstEventID: 1, + NextEventID: 10, + }, + }, + VersionHistories: &persistence.VersionHistories{ + CurrentVersionHistoryIndex: 0, + Histories: []*persistence.VersionHistory{ + { + BranchToken: []byte{1, 2, 3}, + }, + }, + }, + Activities: map[int64]*persistence.ActivityInfo{ + 11: { + Version: 1, + ScheduleID: 11, + }, + }, + History: events.PersistedBlobs{ + events.PersistedBlob{}, + }, + PersistenceError: true, + }) + }, + }, + { + name: "nil snapshot", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockEngine := engine.NewMockEngine(mockCtrl) + if tc.mockSetup != nil { + tc.mockSetup(mockEngine) + } + notifyTasksFromWorkflowSnapshot(mockEngine, tc.workflowSnapShot, tc.history, tc.persistenceError) + }) + } +} + +func TestNotifyTasksFromWorkflowMutation(t *testing.T) { + testCases := []struct { + name string + workflowMutation *persistence.WorkflowMutation + history events.PersistedBlobs + persistenceError bool + mockSetup func(*engine.MockEngine) + }{ + { + name: "Success case", + workflowMutation: &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + VersionHistories: &persistence.VersionHistories{ + CurrentVersionHistoryIndex: 0, + Histories: []*persistence.VersionHistory{ + { + BranchToken: []byte{1, 2, 3}, + }, + }, + }, + UpsertActivityInfos: []*persistence.ActivityInfo{ + { + Version: 1, + ScheduleID: 11, + }, + }, + TransferTasks: []persistence.Task{ + &persistence.ActivityTask{ + TaskList: "test-tl", + }, + }, + TimerTasks: []persistence.Task{ + &persistence.ActivityTimeoutTask{ + Attempt: 10, + }, + }, + CrossClusterTasks: []persistence.Task{ + &persistence.CrossClusterStartChildExecutionTask{ + StartChildExecutionTask: persistence.StartChildExecutionTask{ + TargetDomainID: "target-domain", + }, + TargetCluster: "target", + }, + }, + ReplicationTasks: []persistence.Task{ + &persistence.HistoryReplicationTask{ + FirstEventID: 1, + NextEventID: 10, + }, + }, + }, + history: events.PersistedBlobs{ + events.PersistedBlob{}, + }, + persistenceError: true, + mockSetup: func(mockEngine *engine.MockEngine) { + mockEngine.EXPECT().NotifyNewTransferTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.ActivityTask{ + TaskList: "test-tl", + }, + }, + PersistenceError: true, + }) + mockEngine.EXPECT().NotifyNewTimerTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.ActivityTimeoutTask{ + Attempt: 10, + }, + }, + PersistenceError: true, + }) + mockEngine.EXPECT().NotifyNewCrossClusterTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.CrossClusterStartChildExecutionTask{ + StartChildExecutionTask: persistence.StartChildExecutionTask{ + TargetDomainID: "target-domain", + }, + TargetCluster: "target", + }, + }, + PersistenceError: true, + }) + mockEngine.EXPECT().NotifyNewReplicationTasks(&hcommon.NotifyTaskInfo{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + Tasks: []persistence.Task{ + &persistence.HistoryReplicationTask{ + FirstEventID: 1, + NextEventID: 10, + }, + }, + VersionHistories: &persistence.VersionHistories{ + CurrentVersionHistoryIndex: 0, + Histories: []*persistence.VersionHistory{ + { + BranchToken: []byte{1, 2, 3}, + }, + }, + }, + Activities: map[int64]*persistence.ActivityInfo{ + 11: { + Version: 1, + ScheduleID: 11, + }, + }, + History: events.PersistedBlobs{ + events.PersistedBlob{}, + }, + PersistenceError: true, + }) + }, + }, + { + name: "nil mutation", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockEngine := engine.NewMockEngine(mockCtrl) + if tc.mockSetup != nil { + tc.mockSetup(mockEngine) + } + notifyTasksFromWorkflowMutation(mockEngine, tc.workflowMutation, tc.history, tc.persistenceError) + }) + } +} + +func TestActivityInfosToMap(t *testing.T) { + testCases := []struct { + name string + activities []*persistence.ActivityInfo + want map[int64]*persistence.ActivityInfo + }{ + { + name: "non-empty", + activities: []*persistence.ActivityInfo{ + { + Version: 1, + ScheduleID: 11, + }, + { + Version: 2, + ScheduleID: 12, + }, + }, + want: map[int64]*persistence.ActivityInfo{ + 11: { + Version: 1, + ScheduleID: 11, + }, + 12: { + Version: 2, + ScheduleID: 12, + }, + }, + }, + { + name: "empty slice", + activities: []*persistence.ActivityInfo{}, + want: map[int64]*persistence.ActivityInfo{}, + }, + { + name: "nil slice", + want: map[int64]*persistence.ActivityInfo{}, + }, + } + + for _, tc := range testCases { + assert.Equal(t, tc.want, activityInfosToMap(tc.activities)) + } +} + +func TestCreateWorkflowExecutionWithRetry(t *testing.T) { + testCases := []struct { + name string + request *persistence.CreateWorkflowExecutionRequest + mockSetup func(*shard.MockContext) + want *persistence.CreateWorkflowExecutionResponse + wantErr bool + assertErr func(*testing.T, error) + }{ + { + name: "Success case", + request: &persistence.CreateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().CreateWorkflowExecution(gomock.Any(), &persistence.CreateWorkflowExecutionRequest{ + RangeID: 100, + }).Return(&persistence.CreateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil) + }, + want: &persistence.CreateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, + wantErr: false, + }, + { + name: "workflow already started error", + request: &persistence.CreateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.WorkflowExecutionAlreadyStartedError{}) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, err, &persistence.WorkflowExecutionAlreadyStartedError{}) + }, + }, + { + name: "timeout error", + request: &persistence.CreateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.TimeoutError{}) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, err, &persistence.TimeoutError{}) + }, + }, + { + name: "retry succeeds", + request: &persistence.CreateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &types.ServiceBusyError{}) + mockShard.EXPECT().CreateWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.CreateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil) + }, + want: &persistence.CreateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + policy := backoff.NewExponentialRetryPolicy(time.Millisecond) + policy.SetMaximumAttempts(1) + if tc.mockSetup != nil { + tc.mockSetup(mockShard) + } + resp, err := createWorkflowExecutionWithRetry(context.Background(), mockShard, testlogger.New(t), policy, tc.request) + if tc.wantErr { + assert.Error(t, err) + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tc.want, resp) + } + }) + } +} + +func TestUpdateWorkflowExecutionWithRetry(t *testing.T) { + testCases := []struct { + name string + request *persistence.UpdateWorkflowExecutionRequest + mockSetup func(*shard.MockContext) + want *persistence.UpdateWorkflowExecutionResponse + wantErr bool + assertErr func(*testing.T, error) + }{ + { + name: "Success case", + request: &persistence.UpdateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().UpdateWorkflowExecution(gomock.Any(), &persistence.UpdateWorkflowExecutionRequest{ + RangeID: 100, + }).Return(&persistence.UpdateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil) + }, + want: &persistence.UpdateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, + wantErr: false, + }, + { + name: "condition failed error", + request: &persistence.UpdateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.ConditionFailedError{}) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, err, &conflictError{}) + }, + }, + { + name: "timeout error", + request: &persistence.UpdateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &persistence.TimeoutError{}) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, err, &persistence.TimeoutError{}) + }, + }, + { + name: "retry succeeds", + request: &persistence.UpdateWorkflowExecutionRequest{ + RangeID: 100, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, &types.ServiceBusyError{}) + mockShard.EXPECT().UpdateWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.UpdateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil) + }, + want: &persistence.UpdateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + policy := backoff.NewExponentialRetryPolicy(time.Millisecond) + policy.SetMaximumAttempts(1) + if tc.mockSetup != nil { + tc.mockSetup(mockShard) + } + resp, err := updateWorkflowExecutionWithRetry(context.Background(), mockShard, testlogger.New(t), policy, tc.request) + if tc.wantErr { + assert.Error(t, err) + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tc.want, resp) + } + }) + } +} + +func TestAppendHistoryV2EventsWithRetry(t *testing.T) { + testCases := []struct { + name string + domainID string + execution types.WorkflowExecution + request *persistence.AppendHistoryNodesRequest + mockSetup func(*shard.MockContext) + want *persistence.AppendHistoryNodesResponse + wantErr bool + }{ + { + name: "Success case", + domainID: "test-domain-id", + execution: types.WorkflowExecution{ + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + request: &persistence.AppendHistoryNodesRequest{ + IsNewBranch: true, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().AppendHistoryV2Events(gomock.Any(), &persistence.AppendHistoryNodesRequest{ + IsNewBranch: true, + }, "test-domain-id", types.WorkflowExecution{WorkflowID: "test-workflow-id", RunID: "test-run-id"}).Return(&persistence.AppendHistoryNodesResponse{ + DataBlob: persistence.DataBlob{}, + }, nil) + }, + want: &persistence.AppendHistoryNodesResponse{ + DataBlob: persistence.DataBlob{}, + }, + wantErr: false, + }, + { + name: "retry success", + domainID: "test-domain-id", + execution: types.WorkflowExecution{ + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + request: &persistence.AppendHistoryNodesRequest{ + IsNewBranch: true, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().AppendHistoryV2Events(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &types.ServiceBusyError{}) + mockShard.EXPECT().AppendHistoryV2Events(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(&persistence.AppendHistoryNodesResponse{ + DataBlob: persistence.DataBlob{}, + }, nil) + }, + want: &persistence.AppendHistoryNodesResponse{ + DataBlob: persistence.DataBlob{}, + }, + wantErr: false, + }, + { + name: "non retryable error", + domainID: "test-domain-id", + execution: types.WorkflowExecution{ + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + request: &persistence.AppendHistoryNodesRequest{ + IsNewBranch: true, + }, + mockSetup: func(mockShard *shard.MockContext) { + mockShard.EXPECT().AppendHistoryV2Events(gomock.Any(), &persistence.AppendHistoryNodesRequest{ + IsNewBranch: true, + }, "test-domain-id", types.WorkflowExecution{WorkflowID: "test-workflow-id", RunID: "test-run-id"}).Return(nil, errors.New("some error")) + }, + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + policy := backoff.NewExponentialRetryPolicy(time.Millisecond) + policy.SetMaximumAttempts(1) + if tc.mockSetup != nil { + tc.mockSetup(mockShard) + } + resp, err := appendHistoryV2EventsWithRetry(context.Background(), mockShard, policy, tc.domainID, tc.execution, tc.request) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.want, resp) + } + }) + } +} + +func TestPersistStartWorkflowBatchEvents(t *testing.T) { + testCases := []struct { + name string + workflowEvents *persistence.WorkflowEvents + mockSetup func(*shard.MockContext, *cache.MockDomainCache) + mockAppendHistoryNodesFn func(context.Context, string, types.WorkflowExecution, *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) + wantErr bool + want events.PersistedBlob + assertErr func(*testing.T, error) + }{ + { + name: "empty events", + workflowEvents: &persistence.WorkflowEvents{}, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, err, &types.InternalServiceError{}) + assert.Contains(t, err.Error(), "cannot persist first workflow events with empty events") + }, + }, + { + name: "failed to get domain name", + workflowEvents: &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("", errors.New("some error")) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, err, errors.New("some error")) + }, + }, + { + name: "failed to append history nodes", + workflowEvents: &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + }, + mockAppendHistoryNodesFn: func(context.Context, string, types.WorkflowExecution, *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) { + return nil, errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, err, errors.New("some error")) + }, + }, + { + name: "success", + workflowEvents: &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + }, + mockAppendHistoryNodesFn: func(ctx context.Context, domainID string, execution types.WorkflowExecution, req *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) { + assert.Equal(t, &persistence.AppendHistoryNodesRequest{ + IsNewBranch: true, + Info: "::", + BranchToken: []byte{1, 2, 3}, + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + DomainName: "test-domain", + }, req) + return &persistence.AppendHistoryNodesResponse{ + DataBlob: persistence.DataBlob{ + Data: []byte("123"), + }, + }, nil + }, + want: events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte("123"), + }, + BranchToken: []byte{1, 2, 3}, + FirstEventID: 1, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + if tc.mockSetup != nil { + tc.mockSetup(mockShard, mockDomainCache) + } + ctx := &contextImpl{ + shard: mockShard, + } + if tc.mockAppendHistoryNodesFn != nil { + ctx.appendHistoryNodesFn = tc.mockAppendHistoryNodesFn + } + got, err := ctx.PersistStartWorkflowBatchEvents(context.Background(), tc.workflowEvents) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.want, got) + } + }) + } +} + +func TestPersistNonStartWorkflowBatchEvents(t *testing.T) { + testCases := []struct { + name string + workflowEvents *persistence.WorkflowEvents + mockSetup func(*shard.MockContext, *cache.MockDomainCache) + mockAppendHistoryNodesFn func(context.Context, string, types.WorkflowExecution, *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) + wantErr bool + want events.PersistedBlob + assertErr func(*testing.T, error) + }{ + { + name: "empty events", + workflowEvents: &persistence.WorkflowEvents{}, + wantErr: false, + want: events.PersistedBlob{}, + }, + { + name: "failed to get domain name", + workflowEvents: &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("", errors.New("some error")) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, err, errors.New("some error")) + }, + }, + { + name: "failed to append history nodes", + workflowEvents: &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + }, + mockAppendHistoryNodesFn: func(context.Context, string, types.WorkflowExecution, *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) { + return nil, errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, err, errors.New("some error")) + }, + }, + { + name: "success", + workflowEvents: &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + }, + mockAppendHistoryNodesFn: func(ctx context.Context, domainID string, execution types.WorkflowExecution, req *persistence.AppendHistoryNodesRequest) (*persistence.AppendHistoryNodesResponse, error) { + assert.Equal(t, &persistence.AppendHistoryNodesRequest{ + IsNewBranch: false, + BranchToken: []byte{1, 2, 3}, + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + DomainName: "test-domain", + }, req) + return &persistence.AppendHistoryNodesResponse{ + DataBlob: persistence.DataBlob{ + Data: []byte("123"), + }, + }, nil + }, + want: events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte("123"), + }, + BranchToken: []byte{1, 2, 3}, + FirstEventID: 1, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + if tc.mockSetup != nil { + tc.mockSetup(mockShard, mockDomainCache) + } + ctx := &contextImpl{ + shard: mockShard, + } + if tc.mockAppendHistoryNodesFn != nil { + ctx.appendHistoryNodesFn = tc.mockAppendHistoryNodesFn + } + got, err := ctx.PersistNonStartWorkflowBatchEvents(context.Background(), tc.workflowEvents) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.want, got) + } + }) + } +} + +func TestCreateWorkflowExecution(t *testing.T) { + testCases := []struct { + name string + newWorkflow *persistence.WorkflowSnapshot + history events.PersistedBlob + createMode persistence.CreateWorkflowMode + prevRunID string + prevLastWriteVersion int64 + mockCreateWorkflowExecutionFn func(context.Context, *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) + mockNotifyTasksFromWorkflowSnapshotFn func(*persistence.WorkflowSnapshot, events.PersistedBlobs, bool) + mockEmitSessionUpdateStatsFn func(string, *persistence.MutableStateUpdateSessionStats) + wantErr bool + }{ + { + name: "failed to create workflow execution with possibly success error", + newWorkflow: &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + }, + history: events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte("123"), + }, + BranchToken: []byte{1, 2, 3}, + FirstEventID: 1, + }, + createMode: persistence.CreateWorkflowModeContinueAsNew, + prevRunID: "test-prev-run-id", + prevLastWriteVersion: 123, + mockCreateWorkflowExecutionFn: func(context.Context, *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) { + return nil, &types.InternalServiceError{} + }, + mockNotifyTasksFromWorkflowSnapshotFn: func(_ *persistence.WorkflowSnapshot, _ events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, true, persistenceError) + }, + wantErr: true, + }, + { + name: "success", + newWorkflow: &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + }, + history: events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte("123"), + }, + BranchToken: []byte{1, 2, 3}, + FirstEventID: 1, + }, + createMode: persistence.CreateWorkflowModeContinueAsNew, + prevRunID: "test-prev-run-id", + prevLastWriteVersion: 123, + mockCreateWorkflowExecutionFn: func(ctx context.Context, req *persistence.CreateWorkflowExecutionRequest) (*persistence.CreateWorkflowExecutionResponse, error) { + assert.Equal(t, &persistence.CreateWorkflowExecutionRequest{ + Mode: persistence.CreateWorkflowModeContinueAsNew, + PreviousRunID: "test-prev-run-id", + PreviousLastWriteVersion: 123, + NewWorkflowSnapshot: persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 3, + }, + }, + DomainName: "test-domain", + }, req) + return &persistence.CreateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil + }, + mockNotifyTasksFromWorkflowSnapshotFn: func(newWorkflow *persistence.WorkflowSnapshot, history events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + }, newWorkflow) + assert.Equal(t, events.PersistedBlobs{ + { + DataBlob: persistence.DataBlob{ + Data: []byte("123"), + }, + BranchToken: []byte{1, 2, 3}, + FirstEventID: 1, + }, + }, history) + assert.Equal(t, false, persistenceError) + }, + mockEmitSessionUpdateStatsFn: func(domainName string, stats *persistence.MutableStateUpdateSessionStats) { + assert.Equal(t, "test-domain", domainName) + assert.Equal(t, &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, stats) + }, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + ctx := &contextImpl{ + shard: mockShard, + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + } + if tc.mockCreateWorkflowExecutionFn != nil { + ctx.createWorkflowExecutionFn = tc.mockCreateWorkflowExecutionFn + } + if tc.mockNotifyTasksFromWorkflowSnapshotFn != nil { + ctx.notifyTasksFromWorkflowSnapshotFn = tc.mockNotifyTasksFromWorkflowSnapshotFn + } + if tc.mockEmitSessionUpdateStatsFn != nil { + ctx.emitSessionUpdateStatsFn = tc.mockEmitSessionUpdateStatsFn + } + err := ctx.CreateWorkflowExecution(context.Background(), tc.newWorkflow, tc.history, tc.createMode, tc.prevRunID, tc.prevLastWriteVersion) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestUpdateWorkflowExecutionTasks(t *testing.T) { + testCases := []struct { + name string + mockSetup func(*shard.MockContext, *cache.MockDomainCache, *MockMutableState) + mockUpdateWorkflowExecutionFn func(context.Context, *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) + mockNotifyTasksFromWorkflowMutationFn func(*persistence.WorkflowMutation, events.PersistedBlobs, bool) + mockEmitSessionUpdateStatsFn func(string, *persistence.MutableStateUpdateSessionStats) + wantErr bool + assertErr func(*testing.T, error) + }{ + { + name: "CloseTransactionAsMutation failed", + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(nil, nil, errors.New("some error")) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "found unexpected new events", + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{{}}, nil) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, &types.InternalServiceError{}, err) + }, + }, + { + name: "domain cache error", + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{}, nil) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("", errors.New("some error")) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "update workflow failed with possibly success error", + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{}, nil) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + }, + mockUpdateWorkflowExecutionFn: func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + return nil, &types.InternalServiceError{} + }, + mockNotifyTasksFromWorkflowMutationFn: func(_ *persistence.WorkflowMutation, _ events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, true, persistenceError, "case: update workflow failed with possibly success error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.IsType(t, &types.InternalServiceError{}, err) + }, + }, + { + name: "success", + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + }, []*persistence.WorkflowEvents{}, nil) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + }, + mockUpdateWorkflowExecutionFn: func(_ context.Context, request *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + assert.Equal(t, &persistence.UpdateWorkflowExecutionRequest{ + UpdateWorkflowMutation: persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + ExecutionStats: &persistence.ExecutionStats{}, + }, + Mode: persistence.UpdateWorkflowModeIgnoreCurrent, + DomainName: "test-domain", + }, request, "case: success") + return &persistence.UpdateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil + }, + mockNotifyTasksFromWorkflowMutationFn: func(mutation *persistence.WorkflowMutation, history events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + }, + ExecutionStats: &persistence.ExecutionStats{}, + }, mutation, "case: success") + assert.Nil(t, history, "case: success") + assert.Equal(t, false, persistenceError, "case: success") + }, + mockEmitSessionUpdateStatsFn: func(domainName string, stats *persistence.MutableStateUpdateSessionStats) { + assert.Equal(t, "test-domain", domainName, "case: success") + assert.Equal(t, &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, stats, "case: success") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + mockMutableState := NewMockMutableState(mockCtrl) + if tc.mockSetup != nil { + tc.mockSetup(mockShard, mockDomainCache, mockMutableState) + } + ctx := &contextImpl{ + shard: mockShard, + mutableState: mockMutableState, + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + } + if tc.mockUpdateWorkflowExecutionFn != nil { + ctx.updateWorkflowExecutionFn = tc.mockUpdateWorkflowExecutionFn + } + if tc.mockNotifyTasksFromWorkflowMutationFn != nil { + ctx.notifyTasksFromWorkflowMutationFn = tc.mockNotifyTasksFromWorkflowMutationFn + } + if tc.mockEmitSessionUpdateStatsFn != nil { + ctx.emitSessionUpdateStatsFn = tc.mockEmitSessionUpdateStatsFn + } + err := ctx.UpdateWorkflowExecutionTasks(context.Background(), time.Unix(0, 0)) + if tc.wantErr { + assert.Error(t, err) + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + assert.NoError(t, err) + } + }) + + } +} + +func TestUpdateWorkflowExecutionWithNew(t *testing.T) { + testCases := []struct { + name string + updateMode persistence.UpdateWorkflowMode + newContext Context + currentWorkflowTransactionPolicy TransactionPolicy + newWorkflowTransactionPolicy *TransactionPolicy + mockSetup func(*shard.MockContext, *cache.MockDomainCache, *MockMutableState, *MockMutableState, *engine.MockEngine) + mockPersistNonStartWorkflowBatchEventsFn func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) + mockPersistStartWorkflowBatchEventsFn func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) + mockUpdateWorkflowExecutionFn func(context.Context, *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) + mockNotifyTasksFromWorkflowMutationFn func(*persistence.WorkflowMutation, events.PersistedBlobs, bool) + mockNotifyTasksFromWorkflowSnapshotFn func(*persistence.WorkflowSnapshot, events.PersistedBlobs, bool) + mockEmitSessionUpdateStatsFn func(string, *persistence.MutableStateUpdateSessionStats) + mockEmitWorkflowHistoryStatsFn func(string, int, int) + mockEmitLargeWorkflowShardIDStatsFn func(int64, int64, int64, int64) + mockEmitWorkflowCompletionStatsFn func(string, string, string, string, string, *types.HistoryEvent) + mockMergeContinueAsNewReplicationTasksFn func(persistence.UpdateWorkflowMode, *persistence.WorkflowMutation, *persistence.WorkflowSnapshot) error + mockUpdateWorkflowExecutionEventReapplyFn func(persistence.UpdateWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error + wantErr bool + assertErr func(*testing.T, error) + }{ + { + name: "CloseTransactionAsMutation failed", + currentWorkflowTransactionPolicy: TransactionPolicyPassive, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(nil, nil, errors.New("some error")) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "PersistNonStartWorkflowBatchEvents failed", + currentWorkflowTransactionPolicy: TransactionPolicyPassive, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockMutableState.EXPECT().GetNextEventID().Return(int64(11)) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "CloseTransactionAsSnapshot failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive, + newWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockMutableState.EXPECT().GetNextEventID().Return(int64(11)) + mockMutableState.EXPECT().SetHistorySize(gomock.Any()) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(nil, nil, errors.New("some error")) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "mergeContinueAsNewReplicationTasks failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive, + newWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockMutableState.EXPECT().GetNextEventID().Return(int64(11)) + mockMutableState.EXPECT().SetHistorySize(gomock.Any()) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockMergeContinueAsNewReplicationTasksFn: func(persistence.UpdateWorkflowMode, *persistence.WorkflowMutation, *persistence.WorkflowSnapshot) error { + return errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "updateWorkflowExecutionEventReapply failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive, + newWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockMutableState.EXPECT().GetNextEventID().Return(int64(11)) + mockMutableState.EXPECT().SetHistorySize(gomock.Any()) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockMergeContinueAsNewReplicationTasksFn: func(persistence.UpdateWorkflowMode, *persistence.WorkflowMutation, *persistence.WorkflowSnapshot) error { + return nil + }, + mockUpdateWorkflowExecutionEventReapplyFn: func(persistence.UpdateWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error { + return errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "updateWorkflowExecution failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive, + newWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockMutableState.EXPECT().GetNextEventID().Return(int64(11)) + mockMutableState.EXPECT().SetHistorySize(gomock.Any()) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockMergeContinueAsNewReplicationTasksFn: func(persistence.UpdateWorkflowMode, *persistence.WorkflowMutation, *persistence.WorkflowSnapshot) error { + return nil + }, + mockUpdateWorkflowExecutionEventReapplyFn: func(persistence.UpdateWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error { + return nil + }, + mockUpdateWorkflowExecutionFn: func(context.Context, *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + return nil, errors.New("some error") + }, + mockNotifyTasksFromWorkflowMutationFn: func(_ *persistence.WorkflowMutation, _ events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, true, persistenceError, "case: updateWorkflowExecution failed") + }, + mockNotifyTasksFromWorkflowSnapshotFn: func(_ *persistence.WorkflowSnapshot, _ events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, true, persistenceError, "case: updateWorkflowExecution failed") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "success", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + updateMode: persistence.UpdateWorkflowModeUpdateCurrent, + currentWorkflowTransactionPolicy: TransactionPolicyActive, + newWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), TransactionPolicyActive).Return(&persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateCompleted, + }, + }, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockMutableState.EXPECT().GetNextEventID().Return(int64(11)) + mockMutableState.EXPECT().SetHistorySize(int64(5)) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), TransactionPolicyActive).Return(&persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id2", + }, + }, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + mockMutableState.EXPECT().GetCurrentBranchToken().Return([]byte{5, 6}, nil) + mockMutableState.EXPECT().GetWorkflowStateCloseStatus().Return(persistence.WorkflowStateCompleted, persistence.WorkflowCloseStatusCompleted) + mockShard.EXPECT().GetEngine().Return(mockEngine) + mockEngine.EXPECT().NotifyNewHistoryEvent(gomock.Any()) + mockMutableState.EXPECT().GetLastFirstEventID().Return(int64(1)) + mockMutableState.EXPECT().GetNextEventID().Return(int64(10)) + mockMutableState.EXPECT().GetPreviousStartedEventID().Return(int64(12)) + mockMutableState.EXPECT().GetNextEventID().Return(int64(20)) + mockMutableState.EXPECT().GetCompletionEvent(gomock.Any()).Return(&types.HistoryEvent{ + ID: 123, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(_ context.Context, history *persistence.WorkflowEvents) (events.PersistedBlob, error) { + assert.Equal(t, &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, history, "case: success") + return events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2, 3, 4, 5}, + }, + }, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(_ context.Context, history *persistence.WorkflowEvents) (events.PersistedBlob, error) { + assert.Equal(t, &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, history, "case: success") + return events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte{4, 5}, + }, + }, nil + }, + mockMergeContinueAsNewReplicationTasksFn: func(updateMode persistence.UpdateWorkflowMode, currentWorkflow *persistence.WorkflowMutation, newWorkflow *persistence.WorkflowSnapshot) error { + assert.Equal(t, persistence.UpdateWorkflowModeUpdateCurrent, updateMode) + assert.Equal(t, &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateCompleted, + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 5, + }, + }, currentWorkflow, "case: success") + assert.Equal(t, &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id2", + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 2, + }, + }, newWorkflow, "case: success") + return nil + }, + mockUpdateWorkflowExecutionEventReapplyFn: func(updateMode persistence.UpdateWorkflowMode, currentEvents []*persistence.WorkflowEvents, newEvents []*persistence.WorkflowEvents) error { + assert.Equal(t, persistence.UpdateWorkflowModeUpdateCurrent, updateMode) + assert.Equal(t, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, currentEvents, "case: success") + assert.Equal(t, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, newEvents, "case: success") + return nil + }, + mockUpdateWorkflowExecutionFn: func(_ context.Context, req *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) { + assert.Equal(t, &persistence.UpdateWorkflowExecutionRequest{ + Mode: persistence.UpdateWorkflowModeUpdateCurrent, + UpdateWorkflowMutation: persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateCompleted, + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 5, + }, + }, + NewWorkflowSnapshot: &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id2", + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 2, + }, + }, + DomainName: "test-domain", + }, req, "case: success") + return &persistence.UpdateWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil + }, + mockNotifyTasksFromWorkflowMutationFn: func(currentWorkflow *persistence.WorkflowMutation, currentEvents events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateCompleted, + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 5, + }, + }, currentWorkflow, "case: success") + assert.Equal(t, events.PersistedBlobs{ + { + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2, 3, 4, 5}, + }, + }, + { + DataBlob: persistence.DataBlob{ + Data: []byte{4, 5}, + }, + }, + }, currentEvents, "case: success") + assert.Equal(t, false, persistenceError) + }, + mockNotifyTasksFromWorkflowSnapshotFn: func(newWorkflow *persistence.WorkflowSnapshot, newEvents events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id2", + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 2, + }, + }, newWorkflow, "case: success") + assert.Equal(t, events.PersistedBlobs{ + { + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2, 3, 4, 5}, + }, + }, + { + DataBlob: persistence.DataBlob{ + Data: []byte{4, 5}, + }, + }, + }, newEvents, "case: success") + assert.Equal(t, false, persistenceError, "case: success") + }, + mockEmitWorkflowHistoryStatsFn: func(domainName string, size int, count int) { + assert.Equal(t, 5, size, "case: success") + assert.Equal(t, 19, count, "case: success") + }, + mockEmitSessionUpdateStatsFn: func(domainName string, stats *persistence.MutableStateUpdateSessionStats) { + assert.Equal(t, &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, stats, "case: success") + }, + mockEmitLargeWorkflowShardIDStatsFn: func(blobSize int64, oldHistoryCount int64, oldHistorySize int64, newHistoryCount int64) { + assert.Equal(t, int64(5), blobSize, "case: success") + assert.Equal(t, int64(10), oldHistoryCount, "case: success") + assert.Equal(t, int64(0), oldHistorySize, "case: success") + assert.Equal(t, int64(11), newHistoryCount, "case: success") + }, + mockEmitWorkflowCompletionStatsFn: func(domainName string, workflowType string, workflowID string, runID string, taskList string, lastEvent *types.HistoryEvent) { + assert.Equal(t, &types.HistoryEvent{ + ID: 123, + }, lastEvent, "case: success") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + mockMutableState := NewMockMutableState(mockCtrl) + mockNewMutableState := NewMockMutableState(mockCtrl) + mockEngine := engine.NewMockEngine(mockCtrl) + if tc.mockSetup != nil { + tc.mockSetup(mockShard, mockDomainCache, mockMutableState, mockNewMutableState, mockEngine) + } + ctx := &contextImpl{ + shard: mockShard, + mutableState: mockMutableState, + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + persistNonStartWorkflowBatchEventsFn: tc.mockPersistNonStartWorkflowBatchEventsFn, + persistStartWorkflowBatchEventsFn: tc.mockPersistStartWorkflowBatchEventsFn, + updateWorkflowExecutionFn: tc.mockUpdateWorkflowExecutionFn, + notifyTasksFromWorkflowMutationFn: tc.mockNotifyTasksFromWorkflowMutationFn, + notifyTasksFromWorkflowSnapshotFn: tc.mockNotifyTasksFromWorkflowSnapshotFn, + emitSessionUpdateStatsFn: tc.mockEmitSessionUpdateStatsFn, + emitWorkflowHistoryStatsFn: tc.mockEmitWorkflowHistoryStatsFn, + mergeContinueAsNewReplicationTasksFn: tc.mockMergeContinueAsNewReplicationTasksFn, + updateWorkflowExecutionEventReapplyFn: tc.mockUpdateWorkflowExecutionEventReapplyFn, + emitLargeWorkflowShardIDStatsFn: tc.mockEmitLargeWorkflowShardIDStatsFn, + emitWorkflowCompletionStatsFn: tc.mockEmitWorkflowCompletionStatsFn, + } + err := ctx.UpdateWorkflowExecutionWithNew(context.Background(), time.Unix(0, 0), tc.updateMode, tc.newContext, mockNewMutableState, tc.currentWorkflowTransactionPolicy, tc.newWorkflowTransactionPolicy) + if tc.wantErr { + assert.Error(t, err) + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestConflictResolveWorkflowExecution(t *testing.T) { + testCases := []struct { + name string + conflictResolveMode persistence.ConflictResolveWorkflowMode + newContext Context + currentContext Context + currentWorkflowTransactionPolicy *TransactionPolicy + mockSetup func(*shard.MockContext, *cache.MockDomainCache, *MockMutableState, *MockMutableState, *MockMutableState, *engine.MockEngine) + mockPersistNonStartWorkflowBatchEventsFn func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) + mockPersistStartWorkflowBatchEventsFn func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) + mockUpdateWorkflowExecutionFn func(context.Context, *persistence.UpdateWorkflowExecutionRequest) (*persistence.UpdateWorkflowExecutionResponse, error) + mockNotifyTasksFromWorkflowMutationFn func(*persistence.WorkflowMutation, events.PersistedBlobs, bool) + mockNotifyTasksFromWorkflowSnapshotFn func(*persistence.WorkflowSnapshot, events.PersistedBlobs, bool) + mockEmitSessionUpdateStatsFn func(string, *persistence.MutableStateUpdateSessionStats) + mockEmitWorkflowHistoryStatsFn func(string, int, int) + mockEmitLargeWorkflowShardIDStatsFn func(int64, int64, int64, int64) + mockEmitWorkflowCompletionStatsFn func(string, string, string, string, string, *types.HistoryEvent) + mockMergeContinueAsNewReplicationTasksFn func(persistence.UpdateWorkflowMode, *persistence.WorkflowMutation, *persistence.WorkflowSnapshot) error + mockConflictResolveWorkflowExecutionEventReapplyFn func(persistence.ConflictResolveWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error + wantErr bool + assertErr func(*testing.T, error) + }{ + { + name: "resetMutableState CloseTransactionAsSnapshot failed", + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(nil, nil, errors.New("some error")) + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "persistNonStartWorkflowEvents failed", + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "newMutableState CloseTransactionAsSnapshot failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(nil, nil, errors.New("some error")) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "persistStartWorkflowEvents failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "currentMutableState CloseTransactionAsMutation failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(nil, nil, errors.New("some error")) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "currentMutableState persistNonStartWorkflowEvents failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{5, 6}, + }, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(_ context.Context, history *persistence.WorkflowEvents) (events.PersistedBlob, error) { + if history.BranchToken[0] == 1 { + return events.PersistedBlob{}, nil + } + return events.PersistedBlob{}, errors.New("some error") + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "conflictResolveEventReapply failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{5, 6}, + }, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(_ context.Context, history *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockConflictResolveWorkflowExecutionEventReapplyFn: func(persistence.ConflictResolveWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error { + return errors.New("some error") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "ConflictResolveWorkflowExecution failed", + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{}, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{5, 6}, + }, + }, nil) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + mockShard.EXPECT().ConflictResolveWorkflowExecution(gomock.Any(), gomock.Any()).Return(nil, errors.New("some error")) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(_ context.Context, history *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(context.Context, *persistence.WorkflowEvents) (events.PersistedBlob, error) { + return events.PersistedBlob{}, nil + }, + mockConflictResolveWorkflowExecutionEventReapplyFn: func(persistence.ConflictResolveWorkflowMode, []*persistence.WorkflowEvents, []*persistence.WorkflowEvents) error { + return nil + }, + mockNotifyTasksFromWorkflowMutationFn: func(currentWorkflow *persistence.WorkflowMutation, currentEvents events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, true, persistenceError, "case: ConflictResolveWorkflowExecution failed") + }, + mockNotifyTasksFromWorkflowSnapshotFn: func(newWorkflow *persistence.WorkflowSnapshot, newEvents events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, true, persistenceError, "case: ConflictResolveWorkflowExecution failed") + }, + wantErr: true, + assertErr: func(t *testing.T, err error) { + assert.Equal(t, errors.New("some error"), err) + }, + }, + { + name: "ConflictResolveWorkflowExecution success", + conflictResolveMode: persistence.ConflictResolveWorkflowModeUpdateCurrent, + newContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentContext: &contextImpl{ + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + }, + currentWorkflowTransactionPolicy: TransactionPolicyActive.Ptr(), + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResetMutableState *MockMutableState, mockNewMutableState *MockMutableState, mockMutableState *MockMutableState, mockEngine *engine.MockEngine) { + mockResetMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateCompleted, + }, + }, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, nil) + mockNewMutableState.EXPECT().CloseTransactionAsSnapshot(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id2", + }, + }, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, nil) + mockMutableState.EXPECT().CloseTransactionAsMutation(gomock.Any(), gomock.Any()).Return(&persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id0", + }, + }, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{5, 6}, + }, + }, nil) + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("test-domain", nil) + mockShard.EXPECT().ConflictResolveWorkflowExecution(gomock.Any(), gomock.Any()).Return(&persistence.ConflictResolveWorkflowExecutionResponse{ + MutableStateUpdateSessionStats: &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, + }, nil) + mockResetMutableState.EXPECT().GetCurrentBranchToken().Return([]byte{1}, nil) + mockResetMutableState.EXPECT().GetWorkflowStateCloseStatus().Return(persistence.WorkflowStateCompleted, persistence.WorkflowCloseStatusCompleted) + mockShard.EXPECT().GetEngine().Return(mockEngine) + mockEngine.EXPECT().NotifyNewHistoryEvent(gomock.Any()) + mockResetMutableState.EXPECT().GetLastFirstEventID().Return(int64(123)) + mockResetMutableState.EXPECT().GetNextEventID().Return(int64(456)) + mockResetMutableState.EXPECT().GetPreviousStartedEventID().Return(int64(789)) + mockResetMutableState.EXPECT().GetNextEventID().Return(int64(1111)) + mockResetMutableState.EXPECT().GetCompletionEvent(gomock.Any()).Return(&types.HistoryEvent{ + ID: 123, + }, nil) + }, + mockPersistNonStartWorkflowBatchEventsFn: func(_ context.Context, history *persistence.WorkflowEvents) (events.PersistedBlob, error) { + if history.BranchToken[0] == 1 { + assert.Equal(t, &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, history, "case: success") + return events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2, 3, 4, 5}, + }, + }, nil + } + + assert.Equal(t, &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: 2, + }, + }, + BranchToken: []byte{5, 6}, + }, history, "case: success") + return events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2}, + }, + }, nil + }, + mockPersistStartWorkflowBatchEventsFn: func(_ context.Context, history *persistence.WorkflowEvents) (events.PersistedBlob, error) { + assert.Equal(t, &persistence.WorkflowEvents{ + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, history, "case: success") + return events.PersistedBlob{ + DataBlob: persistence.DataBlob{ + Data: []byte{3, 2}, + }, + }, nil + }, + mockConflictResolveWorkflowExecutionEventReapplyFn: func(mode persistence.ConflictResolveWorkflowMode, resetEvents []*persistence.WorkflowEvents, newEvents []*persistence.WorkflowEvents) error { + assert.Equal(t, persistence.ConflictResolveWorkflowModeUpdateCurrent, mode, "case: success") + assert.Equal(t, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: 1, + }, + }, + BranchToken: []byte{1, 2, 3}, + }, + }, resetEvents, "case: success") + assert.Equal(t, []*persistence.WorkflowEvents{ + { + Events: []*types.HistoryEvent{ + { + ID: common.FirstEventID, + }, + }, + BranchToken: []byte{4}, + }, + }, newEvents, "case: success") + return nil + }, + mockNotifyTasksFromWorkflowMutationFn: func(currentWorkflow *persistence.WorkflowMutation, currentEvents events.PersistedBlobs, persistenceError bool) { + assert.Equal(t, &persistence.WorkflowMutation{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id0", + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 2, + }, + }, currentWorkflow, "case: success") + assert.Equal(t, events.PersistedBlobs{ + { + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2, 3, 4, 5}, + }, + }, + { + DataBlob: persistence.DataBlob{ + Data: []byte{3, 2}, + }, + }, + { + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2}, + }, + }, + }, currentEvents, "case: success") + assert.Equal(t, false, persistenceError, "case: success") + }, + mockNotifyTasksFromWorkflowSnapshotFn: func(newWorkflow *persistence.WorkflowSnapshot, newEvents events.PersistedBlobs, persistenceError bool) { + if newWorkflow.ExecutionInfo.RunID == "test-run-id" { + assert.Equal(t, &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + State: persistence.WorkflowStateCompleted, + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 5, + }, + }, newWorkflow, "case: success") + } else { + assert.Equal(t, &persistence.WorkflowSnapshot{ + ExecutionInfo: &persistence.WorkflowExecutionInfo{ + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id2", + }, + ExecutionStats: &persistence.ExecutionStats{ + HistorySize: 2, + }, + }, newWorkflow, "case: success") + } + assert.Equal(t, events.PersistedBlobs{ + { + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2, 3, 4, 5}, + }, + }, + { + DataBlob: persistence.DataBlob{ + Data: []byte{3, 2}, + }, + }, + { + DataBlob: persistence.DataBlob{ + Data: []byte{1, 2}, + }, + }, + }, newEvents, "case: success") + assert.Equal(t, false, persistenceError, "case: success") + }, + mockEmitWorkflowHistoryStatsFn: func(domainName string, size int, count int) { + assert.Equal(t, 5, size, "case: success") + assert.Equal(t, 1110, count, "case: success") + }, + mockEmitSessionUpdateStatsFn: func(domainName string, stats *persistence.MutableStateUpdateSessionStats) { + assert.Equal(t, &persistence.MutableStateUpdateSessionStats{ + MutableStateSize: 123, + }, stats, "case: success") + }, + mockEmitWorkflowCompletionStatsFn: func(domainName string, workflowType string, workflowID string, runID string, taskList string, lastEvent *types.HistoryEvent) { + assert.Equal(t, &types.HistoryEvent{ + ID: 123, + }, lastEvent, "case: success") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + mockResetMutableState := NewMockMutableState(mockCtrl) + mockMutableState := NewMockMutableState(mockCtrl) + mockNewMutableState := NewMockMutableState(mockCtrl) + mockEngine := engine.NewMockEngine(mockCtrl) + if tc.mockSetup != nil { + tc.mockSetup(mockShard, mockDomainCache, mockResetMutableState, mockNewMutableState, mockMutableState, mockEngine) + } + ctx := &contextImpl{ + shard: mockShard, + stats: &persistence.ExecutionStats{}, + metricsClient: metrics.NewNoopMetricsClient(), + persistNonStartWorkflowBatchEventsFn: tc.mockPersistNonStartWorkflowBatchEventsFn, + persistStartWorkflowBatchEventsFn: tc.mockPersistStartWorkflowBatchEventsFn, + updateWorkflowExecutionFn: tc.mockUpdateWorkflowExecutionFn, + notifyTasksFromWorkflowMutationFn: tc.mockNotifyTasksFromWorkflowMutationFn, + notifyTasksFromWorkflowSnapshotFn: tc.mockNotifyTasksFromWorkflowSnapshotFn, + emitSessionUpdateStatsFn: tc.mockEmitSessionUpdateStatsFn, + emitWorkflowHistoryStatsFn: tc.mockEmitWorkflowHistoryStatsFn, + mergeContinueAsNewReplicationTasksFn: tc.mockMergeContinueAsNewReplicationTasksFn, + conflictResolveEventReapplyFn: tc.mockConflictResolveWorkflowExecutionEventReapplyFn, + emitLargeWorkflowShardIDStatsFn: tc.mockEmitLargeWorkflowShardIDStatsFn, + emitWorkflowCompletionStatsFn: tc.mockEmitWorkflowCompletionStatsFn, + } + err := ctx.ConflictResolveWorkflowExecution(context.Background(), time.Unix(0, 0), tc.conflictResolveMode, mockResetMutableState, tc.newContext, mockNewMutableState, tc.currentContext, mockMutableState, tc.currentWorkflowTransactionPolicy) + if tc.wantErr { + assert.Error(t, err) + if tc.assertErr != nil { + tc.assertErr(t, err) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestReapplyEvents(t *testing.T) { + testCases := []struct { + name string + eventBatches []*persistence.WorkflowEvents + mockSetup func(*shard.MockContext, *cache.MockDomainCache, *resource.Test, *engine.MockEngine) + wantErr bool + }{ + { + name: "empty input", + eventBatches: []*persistence.WorkflowEvents{}, + wantErr: false, + }, + { + name: "domain cache error", + eventBatches: []*persistence.WorkflowEvents{ + { + DomainID: "test-domain-id", + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, _ *resource.Test, _ *engine.MockEngine) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainByID("test-domain-id").Return(nil, errors.New("some error")) + }, + wantErr: true, + }, + { + name: "domain is pending active", + eventBatches: []*persistence.WorkflowEvents{ + { + DomainID: "test-domain-id", + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, _ *resource.Test, _ *engine.MockEngine) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainByID("test-domain-id").Return(cache.NewDomainCacheEntryForTest(nil, nil, true, nil, 0, common.Ptr(int64(1))), nil) + }, + wantErr: false, + }, + { + name: "domainID/workflowID mismatch", + eventBatches: []*persistence.WorkflowEvents{ + { + DomainID: "test-domain-id", + }, + { + DomainID: "test-domain-id2", + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, _ *resource.Test, _ *engine.MockEngine) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainByID("test-domain-id").Return(cache.NewDomainCacheEntryForTest(nil, nil, true, nil, 0, nil), nil) + }, + wantErr: true, + }, + { + name: "no signal events", + eventBatches: []*persistence.WorkflowEvents{ + { + DomainID: "test-domain-id", + }, + { + DomainID: "test-domain-id", + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, _ *resource.Test, _ *engine.MockEngine) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainByID("test-domain-id").Return(cache.NewDomainCacheEntryForTest(nil, nil, true, nil, 0, nil), nil) + }, + wantErr: false, + }, + { + name: "success - apply to current cluster", + eventBatches: []*persistence.WorkflowEvents{ + { + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + Events: []*types.HistoryEvent{ + { + EventType: types.EventTypeWorkflowExecutionSignaled.Ptr(), + }, + }, + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, _ *resource.Test, mockEngine *engine.MockEngine) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainByID("test-domain-id").Return(cache.NewGlobalDomainCacheEntryForTest(nil, nil, &persistence.DomainReplicationConfig{ActiveClusterName: cluster.TestCurrentClusterName}, 0), nil) + mockShard.EXPECT().GetClusterMetadata().Return(cluster.TestActiveClusterMetadata) + mockShard.EXPECT().GetEngine().Return(mockEngine) + mockEngine.EXPECT().ReapplyEvents(gomock.Any(), "test-domain-id", "test-workflow-id", "test-run-id", []*types.HistoryEvent{ + { + EventType: types.EventTypeWorkflowExecutionSignaled.Ptr(), + }, + }).Return(nil) + }, + wantErr: false, + }, + { + name: "success - apply to remote cluster", + eventBatches: []*persistence.WorkflowEvents{ + { + DomainID: "test-domain-id", + WorkflowID: "test-workflow-id", + RunID: "test-run-id", + Events: []*types.HistoryEvent{ + { + EventType: types.EventTypeWorkflowExecutionSignaled.Ptr(), + }, + }, + }, + }, + mockSetup: func(mockShard *shard.MockContext, mockDomainCache *cache.MockDomainCache, mockResource *resource.Test, mockEngine *engine.MockEngine) { + mockShard.EXPECT().GetDomainCache().Return(mockDomainCache) + mockDomainCache.EXPECT().GetDomainByID("test-domain-id").Return(cache.NewGlobalDomainCacheEntryForTest(&persistence.DomainInfo{Name: "test-domain"}, nil, &persistence.DomainReplicationConfig{ActiveClusterName: cluster.TestAlternativeClusterName}, 0), nil) + mockShard.EXPECT().GetClusterMetadata().Return(cluster.TestActiveClusterMetadata) + mockShard.EXPECT().GetService().Return(mockResource).Times(2) + mockResource.RemoteAdminClient.EXPECT().ReapplyEvents(gomock.Any(), gomock.Any()).Return(nil) + }, + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockShard := shard.NewMockContext(mockCtrl) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + mockEngine := engine.NewMockEngine(mockCtrl) + resource := resource.NewTest(t, mockCtrl, metrics.Common) + if tc.mockSetup != nil { + tc.mockSetup(mockShard, mockDomainCache, resource, mockEngine) + } + ctx := &contextImpl{ + shard: mockShard, + } + err := ctx.ReapplyEvents(tc.eventBatches) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/service/history/queue/timer_queue_processor_base.go b/service/history/queue/timer_queue_processor_base.go index 6cf68cede0e..52c67081a2c 100644 --- a/service/history/queue/timer_queue_processor_base.go +++ b/service/history/queue/timer_queue_processor_base.go @@ -174,6 +174,9 @@ func (t *timerQueueProcessorBase) Start() { go t.processorPump() } +// Edge Case: Stop doesn't stop TimerGate if timerQueueProcessorBase is only initiliazed without starting +// As a result, TimerGate needs to be stopped separately +// One way to fix this is to make sure TimerGate doesn't start daemon loop on initilization and requires explicit Start func (t *timerQueueProcessorBase) Stop() { if !atomic.CompareAndSwapInt32(&t.status, common.DaemonStatusStarted, common.DaemonStatusStopped) { return diff --git a/service/history/queue/timer_queue_processor_base_test.go b/service/history/queue/timer_queue_processor_base_test.go index 2d2ed0b9c1a..0ec893dc5f2 100644 --- a/service/history/queue/timer_queue_processor_base_test.go +++ b/service/history/queue/timer_queue_processor_base_test.go @@ -31,6 +31,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/uber-go/tally" + "go.uber.org/goleak" "github.com/uber/cadence/common/cluster" "github.com/uber/cadence/common/dynamicconfig" @@ -93,10 +94,12 @@ func (s *timerQueueProcessorBaseSuite) SetupTest() { func (s *timerQueueProcessorBaseSuite) TearDownTest() { s.controller.Finish() s.mockShard.Finish(s.T()) + goleak.VerifyNone(s.T()) } func (s *timerQueueProcessorBaseSuite) TestIsProcessNow() { - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() s.True(timerQueueProcessBase.isProcessNow(time.Time{})) now := s.mockShard.GetCurrentTime(s.clusterName) @@ -107,6 +110,7 @@ func (s *timerQueueProcessorBaseSuite) TestIsProcessNow() { timeAfter := now.Add(10 * time.Second) s.False(timerQueueProcessBase.isProcessNow(timeAfter)) + } func (s *timerQueueProcessorBaseSuite) TestGetTimerTasks_More() { @@ -141,7 +145,8 @@ func (s *timerQueueProcessorBaseSuite) TestGetTimerTasks_More() { mockExecutionMgr := s.mockShard.Resource.ExecutionMgr mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() got, err := timerQueueProcessBase.getTimerTasks(readLevel, maxReadLevel, request.NextPageToken, batchSize) s.Nil(err) s.Equal(response.Timers, got.Timers) @@ -180,7 +185,8 @@ func (s *timerQueueProcessorBaseSuite) TestGetTimerTasks_NoMore() { mockExecutionMgr := s.mockShard.Resource.ExecutionMgr mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() got, err := timerQueueProcessBase.getTimerTasks(readLevel, maxReadLevel, request.NextPageToken, batchSize) s.Nil(err) s.Equal(response.Timers, got.Timers) @@ -220,7 +226,8 @@ func (s *timerQueueProcessorBaseSuite) TestReadLookAheadTask() { mockExecutionMgr := s.mockShard.Resource.ExecutionMgr mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() lookAheadTask, err := timerQueueProcessBase.readLookAheadTask(readLevel, maxReadLevel) s.Nil(err) s.Equal(response.Timers[0], lookAheadTask) @@ -265,7 +272,8 @@ func (s *timerQueueProcessorBaseSuite) TestReadAndFilterTasks_NoLookAhead_NoNext mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, lookAheadRequest).Return(&persistence.GetTimerIndexTasksResponse{}, nil).Once() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() got, err := timerQueueProcessBase.readAndFilterTasks(readLevel, maxReadLevel, request.NextPageToken) s.Nil(err) s.Equal(response.Timers, got.timerTasks) @@ -304,7 +312,8 @@ func (s *timerQueueProcessorBaseSuite) TestReadAndFilterTasks_NoLookAhead_HasNex mockExecutionMgr := s.mockShard.Resource.ExecutionMgr mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() got, err := timerQueueProcessBase.readAndFilterTasks(readLevel, maxReadLevel, request.NextPageToken) s.Nil(err) s.Equal(response.Timers, got.timerTasks) @@ -354,7 +363,8 @@ func (s *timerQueueProcessorBaseSuite) TestReadAndFilterTasks_HasLookAhead_NoNex mockExecutionMgr := s.mockShard.Resource.ExecutionMgr mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() got, err := timerQueueProcessBase.readAndFilterTasks(readLevel, maxReadLevel, request.NextPageToken) s.Nil(err) s.Equal([]*persistence.TimerTaskInfo{response.Timers[0]}, got.timerTasks) @@ -404,7 +414,8 @@ func (s *timerQueueProcessorBaseSuite) TestReadAndFilterTasks_HasLookAhead_HasNe mockExecutionMgr := s.mockShard.Resource.ExecutionMgr mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() got, err := timerQueueProcessBase.readAndFilterTasks(readLevel, maxReadLevel, request.NextPageToken) s.Nil(err) s.Equal([]*persistence.TimerTaskInfo{response.Timers[0]}, got.timerTasks) @@ -462,7 +473,8 @@ func (s *timerQueueProcessorBaseSuite) TestReadAndFilterTasks_LookAheadFailed_No mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, request).Return(response, nil).Once() mockExecutionMgr.On("GetTimerIndexTasks", mock.Anything, lookAheadRequest).Return(nil, errors.New("some random error")).Times(s.mockShard.GetConfig().TimerProcessorGetFailureRetryCount()) - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() got, err := timerQueueProcessBase.readAndFilterTasks(readLevel, maxReadLevel, request.NextPageToken) s.Nil(err) s.Equal(response.Timers, got.timerTasks) @@ -471,7 +483,8 @@ func (s *timerQueueProcessorBaseSuite) TestReadAndFilterTasks_LookAheadFailed_No } func (s *timerQueueProcessorBaseSuite) TestNotifyNewTimes() { - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(nil, nil, nil, nil, nil) + defer done() // assert the initial state s.True(timerQueueProcessBase.newTime.IsZero()) @@ -539,7 +552,8 @@ func (s *timerQueueProcessorBaseSuite) TestProcessQueueCollections_SkipRead() { return shardMaxReadLevel } - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + defer done() timerQueueProcessBase.processQueueCollections(map[int]struct{}{queueLevel: {}}) s.Len(timerQueueProcessBase.processingQueueCollections, 1) @@ -620,7 +634,8 @@ func (s *timerQueueProcessorBaseSuite) TestProcessBatch_HasNextPage() { s.mockTaskProcessor.EXPECT().TrySubmit(gomock.Any()).Return(true, nil).AnyTimes() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + defer done() timerQueueProcessBase.processQueueCollections(map[int]struct{}{queueLevel: {}}) s.Len(timerQueueProcessBase.processingQueueCollections, 1) @@ -710,7 +725,8 @@ func (s *timerQueueProcessorBaseSuite) TestProcessBatch_NoNextPage_HasLookAhead( s.mockTaskProcessor.EXPECT().TrySubmit(gomock.Any()).Return(true, nil).AnyTimes() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + defer done() timerQueueProcessBase.processingQueueReadProgress[0] = timeTaskReadProgress{ currentQueue: timerQueueProcessBase.processingQueueCollections[0].ActiveQueue(), readLevel: ackLevel, @@ -807,7 +823,8 @@ func (s *timerQueueProcessorBaseSuite) TestProcessBatch_NoNextPage_NoLookAhead() s.mockTaskProcessor.EXPECT().TrySubmit(gomock.Any()).Return(true, nil).AnyTimes() - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + timerQueueProcessBase, done := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + defer done() timerQueueProcessBase.processingQueueReadProgress[0] = timeTaskReadProgress{ currentQueue: timerQueueProcessBase.processingQueueCollections[0].ActiveQueue(), readLevel: ackLevel, @@ -854,7 +871,7 @@ func (s *timerQueueProcessorBaseSuite) TestTimerProcessorPump_HandleAckLevelUpda return newTimerTaskKey(now, 0) } - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + timerQueueProcessBase, _ := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) timerQueueProcessBase.options.UpdateAckInterval = dynamicconfig.GetDurationPropertyFn(1 * time.Millisecond) updatedCh := make(chan struct{}, 1) timerQueueProcessBase.updateAckLevelFn = func() (bool, task.Key, error) { @@ -889,7 +906,7 @@ func (s *timerQueueProcessorBaseSuite) TestTimerProcessorPump_SplitQueue() { return newTimerTaskKey(now, 0) } - timerQueueProcessBase := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) + timerQueueProcessBase, _ := s.newTestTimerQueueProcessorBase(processingQueueStates, updateMaxReadLevel, nil, nil, nil) timerQueueProcessBase.options.SplitQueueInterval = dynamicconfig.GetDurationPropertyFn(1 * time.Millisecond) splittedCh := make(chan struct{}, 1) timerQueueProcessBase.splitProcessingQueueCollectionFn = func(splitPolicy ProcessingQueueSplitPolicy, upsertPollTimeFn func(int, time.Time)) { @@ -912,21 +929,25 @@ func (s *timerQueueProcessorBaseSuite) newTestTimerQueueProcessorBase( updateClusterAckLevel updateClusterAckLevelFn, updateProcessingQueueStates updateProcessingQueueStatesFn, queueShutdown queueShutdownFn, -) *timerQueueProcessorBase { +) (*timerQueueProcessorBase, func()) { + timerGate := NewLocalTimerGate(s.mockShard.GetTimeSource()) + return newTimerQueueProcessorBase( - s.clusterName, - s.mockShard, - processingQueueStates, - s.mockTaskProcessor, - NewLocalTimerGate(s.mockShard.GetTimeSource()), - newTimerQueueProcessorOptions(s.mockShard.GetConfig(), true, false), - updateMaxReadLevel, - updateClusterAckLevel, - updateProcessingQueueStates, - queueShutdown, - nil, - nil, - s.logger, - s.metricsClient, - ) + s.clusterName, + s.mockShard, + processingQueueStates, + s.mockTaskProcessor, + timerGate, + newTimerQueueProcessorOptions(s.mockShard.GetConfig(), true, false), + updateMaxReadLevel, + updateClusterAckLevel, + updateProcessingQueueStates, + queueShutdown, + nil, + nil, + s.logger, + s.metricsClient, + ), func() { + timerGate.Close() + } } diff --git a/service/history/queue/transfer_queue_processor_test.go b/service/history/queue/transfer_queue_processor_test.go new file mode 100644 index 00000000000..7dada053e44 --- /dev/null +++ b/service/history/queue/transfer_queue_processor_test.go @@ -0,0 +1,75 @@ +// The MIT License (MIT) + +// Copyright (c) 2017-2020 Uber Technologies Inc. + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package queue + +import ( + "testing" + + "github.com/golang/mock/gomock" + "go.uber.org/goleak" + + "github.com/uber/cadence/common/persistence" + "github.com/uber/cadence/common/reconciliation/invariant" + "github.com/uber/cadence/service/history/config" + "github.com/uber/cadence/service/history/execution" + "github.com/uber/cadence/service/history/reset" + "github.com/uber/cadence/service/history/shard" + "github.com/uber/cadence/service/history/task" + "github.com/uber/cadence/service/history/workflowcache" + "github.com/uber/cadence/service/worker/archiver" +) + +func TestTransferQueueProcessor_RequireStartStop(t *testing.T) { + // some goroutine leak not from this test + defer goleak.VerifyNone(t) + ctrl := gomock.NewController(t) + mockShard := shard.NewTestContext( + t, ctrl, &persistence.ShardInfo{ + ShardID: 10, + RangeID: 1, + TransferAckLevel: 0, + }, + config.NewForTest()) + defer mockShard.Finish(t) + + mockProcessor := task.NewMockProcessor(ctrl) + mockResetter := reset.NewMockWorkflowResetter(ctrl) + mockArchiver := &archiver.ClientMock{} + mockInvariant := invariant.NewMockInvariant(ctrl) + mockWorkflowCache := workflowcache.NewMockWFCache(ctrl) + ratelimit := func(domain string) bool { return false } + + // Create a new transferQueueProcessor + processor := NewTransferQueueProcessor( + mockShard, + mockShard.GetEngine(), + mockProcessor, + execution.NewCache(mockShard), + mockResetter, + mockArchiver, + mockInvariant, + mockWorkflowCache, + ratelimit) + processor.Start() + processor.Stop() +}