diff --git a/service/history/api/replication/stream.go b/service/history/api/replication/stream.go
index 4476692a57b..28659f1ca85 100644
--- a/service/history/api/replication/stream.go
+++ b/service/history/api/replication/stream.go
@@ -22,6 +22,8 @@
 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 // THE SOFTWARE.
 
+//go:generate mockgen -copyright_file ../../../../LICENSE -package $GOPACKAGE -source $GOFILE -destination stream_mock.go
+
 package replication
 
 import (
@@ -34,25 +36,55 @@ import (
 	"go.temporal.io/server/api/historyservice/v1"
 	replicationspb "go.temporal.io/server/api/replication/v1"
 	historyclient "go.temporal.io/server/client/history"
+	"go.temporal.io/server/common"
 	"go.temporal.io/server/common/log/tag"
+	"go.temporal.io/server/common/namespace"
 	"go.temporal.io/server/common/persistence"
 	"go.temporal.io/server/common/primitives/timestamp"
 	"go.temporal.io/server/service/history/shard"
 	"go.temporal.io/server/service/history/tasks"
 )
 
+type (
+	TaskConvertorImpl struct {
+		Ctx                     context.Context
+		Engine                  shard.Engine
+		NamespaceCache          namespace.Registry
+		SourceClusterShardCount int32
+		SourceClusterShardID    historyclient.ClusterShardID
+	}
+	TaskConvertor interface {
+		Convert(task tasks.Task) (*replicationspb.ReplicationTask, error)
+	}
+)
+
 func StreamReplicationTasks(
 	server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer,
 	shardContext shard.Context,
 	sourceClusterShardID historyclient.ClusterShardID,
 	targetClusterShardID historyclient.ClusterShardID,
 ) error {
+	sourceClusterInfo, ok := shardContext.GetClusterMetadata().GetAllClusterInfo()[sourceClusterShardID.ClusterName]
+	if !ok {
+		return serviceerror.NewInternal(fmt.Sprintf("Unknown cluster: %v", sourceClusterShardID.ClusterName))
+	}
+	engine, err := shardContext.GetEngine(server.Context())
+	if err != nil {
+		return err
+	}
+	filter := &TaskConvertorImpl{
+		Ctx:                     server.Context(),
+		Engine:                  engine,
+		NamespaceCache:          shardContext.GetNamespaceRegistry(),
+		SourceClusterShardCount: sourceClusterInfo.ShardCount,
+		SourceClusterShardID:    sourceClusterShardID,
+	}
 	errGroup, ctx := errgroup.WithContext(server.Context())
 	errGroup.Go(func() error {
 		return recvLoop(ctx, server, shardContext, sourceClusterShardID)
 	})
 	errGroup.Go(func() error {
-		return sendLoop(ctx, server, shardContext, sourceClusterShardID, targetClusterShardID)
+		return sendLoop(ctx, server, shardContext, filter, sourceClusterShardID)
 	})
 	return errGroup.Wait()
 }
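For reviewers tracing the control flow above: `errgroup.WithContext` ties the recv and send loops together, so a failure in either loop cancels the shared context and unwinds the other. A minimal, self-contained sketch of that duplex pattern (names here are illustrative, not from this PR):

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	errGroup, ctx := errgroup.WithContext(context.Background())
	errGroup.Go(func() error {
		<-ctx.Done() // stand-in for recvLoop: blocks until its peer fails
		return ctx.Err()
	})
	errGroup.Go(func() error {
		return errors.New("send loop failed") // stand-in for sendLoop failing
	})
	// Wait blocks until both goroutines exit and returns the first error,
	// which is how StreamReplicationTasks surfaces a broken stream.
	fmt.Println(errGroup.Wait()) // send loop failed
}
```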
@@ -98,24 +130,27 @@ func recvSyncReplicationState(
 ) error {
 	lastProcessedMessageID := attr.GetLastProcessedMessageId()
 	lastProcessedMessageIDTime := attr.GetLastProcessedMessageTime()
-	if lastProcessedMessageID != persistence.EmptyQueueMessageID {
-		if err := shardContext.UpdateQueueClusterAckLevel(
-			tasks.CategoryReplication,
-			sourceClusterShardID.ClusterName,
-			tasks.NewImmediateKey(lastProcessedMessageID),
-		); err != nil {
-			shardContext.GetLogger().Error(
-				"error updating replication level for shard",
-				tag.Error(err),
-				tag.OperationFailed,
-			)
-		}
-		shardContext.UpdateRemoteClusterInfo(
-			sourceClusterShardID.ClusterName,
-			lastProcessedMessageID,
-			*lastProcessedMessageIDTime,
+	if lastProcessedMessageID == persistence.EmptyQueueMessageID {
+		return nil
+	}
+
+	// TODO wait for #4176 to be merged and then use cluster & shard ID as reader ID
+	if err := shardContext.UpdateQueueClusterAckLevel(
+		tasks.CategoryReplication,
+		sourceClusterShardID.ClusterName,
+		tasks.NewImmediateKey(lastProcessedMessageID),
+	); err != nil {
+		shardContext.GetLogger().Error(
+			"error updating replication level for shard",
+			tag.Error(err),
+			tag.OperationFailed,
 		)
 	}
+	shardContext.UpdateRemoteClusterInfo(
+		sourceClusterShardID.ClusterName,
+		lastProcessedMessageID,
+		*lastProcessedMessageIDTime,
+	)
 	return nil
 }
 
@@ -123,8 +158,8 @@ func sendLoop(
 	ctx context.Context,
 	server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer,
 	shardContext shard.Context,
+	taskConvertor TaskConvertor,
 	sourceClusterShardID historyclient.ClusterShardID,
-	targetClusterShardID historyclient.ClusterShardID,
 ) error {
 	engine, err := shardContext.GetEngine(ctx)
 	if err != nil {
@@ -137,8 +172,8 @@ func sendLoop(
 		ctx,
 		server,
 		shardContext,
+		taskConvertor,
 		sourceClusterShardID,
-		targetClusterShardID,
 	)
 	if err != nil {
 		shardContext.GetLogger().Error(
@@ -152,8 +187,8 @@ func sendLoop(
 		ctx,
 		server,
 		shardContext,
+		taskConvertor,
 		sourceClusterShardID,
-		targetClusterShardID,
 		newTaskNotificationChan,
 		catchupEndExclusiveWatermark,
 	); err != nil {
@@ -172,9 +207,10 @@ func sendCatchUp(
 	ctx context.Context,
 	server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer,
 	shardContext shard.Context,
+	taskConvertor TaskConvertor,
 	sourceClusterShardID historyclient.ClusterShardID,
-	targetClusterShardID historyclient.ClusterShardID,
 ) (int64, error) {
+	// TODO wait for #4176 to be merged and then use cluster & shard ID as reader ID
 	catchupBeginInclusiveWatermark := shardContext.GetQueueClusterAckLevel(
 		tasks.CategoryReplication,
 		sourceClusterShardID.ClusterName,
@@ -184,8 +220,8 @@ func sendCatchUp(
 		ctx,
 		server,
 		shardContext,
+		taskConvertor,
 		sourceClusterShardID,
-		targetClusterShardID,
 		catchupBeginInclusiveWatermark.TaskID,
 		catchupEndExclusiveWatermark.TaskID,
 	); err != nil {
@@ -198,8 +234,8 @@ func sendLive(
 	ctx context.Context,
 	server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer,
 	shardContext shard.Context,
+	taskConvertor TaskConvertor,
 	sourceClusterShardID historyclient.ClusterShardID,
-	targetClusterShardID historyclient.ClusterShardID,
 	newTaskNotificationChan <-chan struct{},
 	beginInclusiveWatermark int64,
 ) error {
@@ -211,8 +247,8 @@ func sendLive(
 				ctx,
 				server,
 				shardContext,
+				taskConvertor,
 				sourceClusterShardID,
-				targetClusterShardID,
 				beginInclusiveWatermark,
 				endExclusiveWatermark,
 			); err != nil {
@@ -229,8 +265,8 @@ func sendTasks(
 	ctx context.Context,
 	server historyservice.HistoryService_StreamWorkflowReplicationMessagesServer,
 	shardContext shard.Context,
+	taskConvertor TaskConvertor,
 	sourceClusterShardID historyclient.ClusterShardID,
-	targetClusterShardID historyclient.ClusterShardID,
 	beginInclusiveWatermark int64,
 	endExclusiveWatermark int64,
 ) error {
@@ -261,7 +297,7 @@ Loop:
 		if err != nil {
 			return err
 		}
-		task, err := engine.ConvertReplicationTask(ctx, item)
+		task, err := taskConvertor.Convert(item)
 		if err != nil {
 			return err
 		}
@@ -290,3 +326,35 @@ Loop:
 		},
 	})
 }
+
+func (f *TaskConvertorImpl) Convert(
+	task tasks.Task,
+) (*replicationspb.ReplicationTask, error) {
+	if namespaceEntry, err := f.NamespaceCache.GetNamespaceByID(
+		namespace.ID(task.GetNamespaceID()),
+	); err == nil {
+		shouldProcessTask := false
+	FilterLoop:
+		for _, targetCluster := range namespaceEntry.ClusterNames() {
+			if f.SourceClusterShardID.ClusterName == targetCluster {
+				shouldProcessTask = true
+				break FilterLoop
+			}
+		}
+		if !shouldProcessTask {
+			return nil, nil
+		}
+	}
+	// if there is an error, blindly send the task: better safe than sorry
+
+	sourceShardID := common.WorkflowIDToHistoryShard(task.GetNamespaceID(), task.GetWorkflowID(), f.SourceClusterShardCount)
+	if sourceShardID != f.SourceClusterShardID.ShardID {
+		return nil, nil
+	}
+
+	replicationTask, err := f.Engine.ConvertReplicationTask(f.Ctx, task)
+	if err != nil {
+		return nil, err
+	}
+	return replicationTask, nil
+}
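The `sourceShardID != f.SourceClusterShardID.ShardID` check in `Convert` only works because shard ownership is a pure function of the workflow key. A sketch of that routing, assuming the farm-hash scheme behind `common.WorkflowIDToHistoryShard` (simplified; the canonical definition lives in the `common` package):

```go
package main

import (
	"fmt"

	farm "github.com/dgryski/go-farm"
)

// workflowIDToShard mirrors the assumed contract of
// common.WorkflowIDToHistoryShard: deterministic hashing, 1-based shard IDs.
func workflowIDToShard(namespaceID string, workflowID string, shardCount int32) int32 {
	hash := farm.Fingerprint32([]byte(namespaceID + "_" + workflowID))
	return int32(hash%uint32(shardCount)) + 1
}

func main() {
	// When source and target clusters have different shard counts, the same
	// workflow generally lands on different shard IDs, which is why Convert
	// re-derives the shard under the source cluster's count before sending.
	fmt.Println(workflowIDToShard("ns-id", "wf-id", 4))
	fmt.Println(workflowIDToShard("ns-id", "wf-id", 16))
}
```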
diff --git a/service/history/api/replication/stream_mock.go b/service/history/api/replication/stream_mock.go
new file mode 100644
index 00000000000..d6086deb7a7
--- /dev/null
+++ b/service/history/api/replication/stream_mock.go
@@ -0,0 +1,75 @@
+// The MIT License
+//
+// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
+//
+// Copyright (c) 2020 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+// Code generated by MockGen. DO NOT EDIT.
+// Source: stream.go
+
+// Package replication is a generated GoMock package.
+package replication
+
+import (
+	reflect "reflect"
+
+	gomock "github.com/golang/mock/gomock"
+	repication "go.temporal.io/server/api/replication/v1"
+	tasks "go.temporal.io/server/service/history/tasks"
+)
+
+// MockTaskConvertor is a mock of TaskConvertor interface.
+type MockTaskConvertor struct {
+	ctrl     *gomock.Controller
+	recorder *MockTaskConvertorMockRecorder
+}
+
+// MockTaskConvertorMockRecorder is the mock recorder for MockTaskConvertor.
+type MockTaskConvertorMockRecorder struct {
+	mock *MockTaskConvertor
+}
+
+// NewMockTaskConvertor creates a new mock instance.
+func NewMockTaskConvertor(ctrl *gomock.Controller) *MockTaskConvertor {
+	mock := &MockTaskConvertor{ctrl: ctrl}
+	mock.recorder = &MockTaskConvertorMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockTaskConvertor) EXPECT() *MockTaskConvertorMockRecorder {
+	return m.recorder
+}
+
+// Convert mocks base method.
+func (m *MockTaskConvertor) Convert(task tasks.Task) (*repication.ReplicationTask, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "Convert", task)
+	ret0, _ := ret[0].(*repication.ReplicationTask)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// Convert indicates an expected call of Convert.
+func (mr *MockTaskConvertorMockRecorder) Convert(task interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Convert", reflect.TypeOf((*MockTaskConvertor)(nil).Convert), task)
+}
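The generated mock slots into the new `taskConvertor` parameter, which is what lets the send-path tests below drop their dependency on `shard.MockEngine`. A minimal usage sketch (the test name and the zero-value task are illustrative):

```go
package replication

import (
	"testing"

	"github.com/golang/mock/gomock"

	"go.temporal.io/server/service/history/tasks"
)

func TestConvertCanBeStubbed(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	convertor := NewMockTaskConvertor(ctrl)
	// A (nil, nil) return is the "filtered out" signal: sendTasks skips the
	// item without sending anything to the stream.
	convertor.EXPECT().Convert(gomock.Any()).Return(nil, nil)

	var item tasks.Task // zero value is enough for the stubbed call
	replicationTask, err := convertor.Convert(item)
	if err != nil || replicationTask != nil {
		t.Fatalf("expected task to be filtered out, got %v, %v", replicationTask, err)
	}
}
```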
diff --git a/service/history/api/replication/stream_test.go b/service/history/api/replication/stream_test.go
index 4cfb7b61ace..847e3736677 100644
--- a/service/history/api/replication/stream_test.go
+++ b/service/history/api/replication/stream_test.go
@@ -54,6 +54,7 @@ type (
 		server        *historyservicemock.MockHistoryService_StreamWorkflowReplicationMessagesServer
 		shardContext  *shard.MockContext
 		historyEngine *shard.MockEngine
+		taskConvertor *MockTaskConvertor
 
 		ctx    context.Context
 		cancel context.CancelFunc
@@ -62,7 +63,7 @@ type (
 	}
 )
 
-func TestBiDirectionStreamSuite(t *testing.T) {
+func TestStreamSuite(t *testing.T) {
 	s := new(streamSuite)
 	suite.Run(t, s)
 }
@@ -82,6 +83,7 @@ func (s *streamSuite) SetupTest() {
 	s.server = historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesServer(s.controller)
 	s.shardContext = shard.NewMockContext(s.controller)
 	s.historyEngine = shard.NewMockEngine(s.controller)
+	s.taskConvertor = NewMockTaskConvertor(s.controller)
 
 	s.ctx, s.cancel = context.WithCancel(context.Background())
 	s.sourceClusterShardID = historyclient.ClusterShardID{
@@ -152,8 +154,8 @@ func (s *streamSuite) TestSendCatchUp() {
 		s.ctx,
 		s.server,
 		s.shardContext,
+		s.taskConvertor,
 		s.sourceClusterShardID,
-		s.targetClusterShardID,
 	)
 	s.NoError(err)
 	s.Equal(endExclusiveWatermark, taskID)
@@ -213,8 +215,8 @@ func (s *streamSuite) TestSendLive() {
 		s.ctx,
 		s.server,
 		s.shardContext,
+		s.taskConvertor,
 		s.sourceClusterShardID,
-		s.targetClusterShardID,
 		channel,
 		watermark0,
 	)
@@ -229,8 +231,8 @@ func (s *streamSuite) TestSendTasks_Noop() {
 		s.ctx,
 		s.server,
 		s.shardContext,
+		s.taskConvertor,
 		s.sourceClusterShardID,
-		s.targetClusterShardID,
 		beginInclusiveWatermark,
 		endExclusiveWatermark,
 	)
@@ -262,8 +264,8 @@ func (s *streamSuite) TestSendTasks_WithoutTasks() {
 		s.ctx,
 		s.server,
 		s.shardContext,
+		s.taskConvertor,
 		s.sourceClusterShardID,
-		s.targetClusterShardID,
 		beginInclusiveWatermark,
 		endExclusiveWatermark,
 	)
@@ -296,9 +298,9 @@ func (s *streamSuite) TestSendTasks_WithTasks() {
 		beginInclusiveWatermark,
 		endExclusiveWatermark,
 	).Return(iter, nil)
-	s.historyEngine.EXPECT().ConvertReplicationTask(s.ctx, item0).Return(task0, nil)
-	s.historyEngine.EXPECT().ConvertReplicationTask(s.ctx, item1).Return(nil, nil)
-	s.historyEngine.EXPECT().ConvertReplicationTask(s.ctx, item2).Return(task2, nil)
+	s.taskConvertor.EXPECT().Convert(item0).Return(task0, nil)
+	s.taskConvertor.EXPECT().Convert(item1).Return(nil, nil)
+	s.taskConvertor.EXPECT().Convert(item2).Return(task2, nil)
 	gomock.InOrder(
 		s.server.EXPECT().Send(&historyservice.StreamWorkflowReplicationMessagesResponse{
 			Attributes: &historyservice.StreamWorkflowReplicationMessagesResponse_Messages{
@@ -329,8 +331,8 @@ func (s *streamSuite) TestSendTasks_WithTasks() {
 		s.ctx,
 		s.server,
 		s.shardContext,
+		s.taskConvertor,
 		s.sourceClusterShardID,
-		s.targetClusterShardID,
 		beginInclusiveWatermark,
 		endExclusiveWatermark,
 	)
diff --git a/service/history/replication/stream_receiver_monitor.go b/service/history/replication/stream_receiver_monitor.go
index 2f9c0a6f60e..e813d954b52 100644
--- a/service/history/replication/stream_receiver_monitor.go
+++ b/service/history/replication/stream_receiver_monitor.go
@@ -146,7 +146,8 @@ func (m *StreamReceiverMonitorImpl) generateStreamKeys() map[ClusterShardKeyPair]struct{} {
 	sourceClusterName := m.ClusterMetadata.GetCurrentClusterName()
 	targetClusterNames := make(map[string]struct{})
-	for clusterName, clusterInfo := range m.ClusterMetadata.GetAllClusterInfo() {
+	allClusterInfo := m.ClusterMetadata.GetAllClusterInfo()
+	for clusterName, clusterInfo := range allClusterInfo {
 		if !clusterInfo.Enabled || clusterName == sourceClusterName {
 			continue
 		}
 		targetClusterNames[clusterName] = struct{}{}
@@ -155,13 +156,20 @@ func (m *StreamReceiverMonitorImpl) generateStreamKeys() map[ClusterShardKeyPair]struct{} {
 	streamKeys := make(map[ClusterShardKeyPair]struct{})
 	for _, shardID := range m.ShardController.ShardIDs() {
 		for targetClusterName := range targetClusterNames {
+			// NOTE:
+			// source: client side of the replication stream, i.e. the receiver of replication tasks
+			// target: server side of the replication stream, i.e. the sender of replication tasks
 			sourceShardID := shardID
-			// TODO src shards !necessary= target shards, add conversion fn here
-			targetShardID := shardID
-			streamKeys[ClusterShardKeyPair{
-				Source: NewClusterShardKey(sourceClusterName, sourceShardID),
-				Target: NewClusterShardKey(targetClusterName, targetShardID),
-			}] = struct{}{}
+			for _, targetShardID := range common.MapShardID(
+				allClusterInfo[sourceClusterName].ShardCount,
+				allClusterInfo[targetClusterName].ShardCount,
+				sourceShardID,
+			) {
+				streamKeys[ClusterShardKeyPair{
+					Source: NewClusterShardKey(sourceClusterName, sourceShardID),
+					Target: NewClusterShardKey(targetClusterName, targetShardID),
+				}] = struct{}{}
+			}
 		}
 	}
 	return streamKeys
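`common.MapShardID` is the conversion function the deleted TODO asked for. Because shard ownership is `hash % shardCount` (see the routing sketch after the stream.go diff), a source shard maps cleanly onto target shards whenever one shard count divides the other: scaling down is a plain modulo, and scaling up fans one source shard out to `target/source` candidates. A sketch of that arithmetic under the divisibility assumption (the real helper lives in `common`; this is not its verbatim implementation):

```go
package main

import "fmt"

// mapShardID sketches the assumed common.MapShardID contract: 1-based shard
// IDs, and one shard count must be a multiple of the other.
func mapShardID(sourceShardCount, targetShardCount, sourceShardID int32) []int32 {
	if sourceShardCount%targetShardCount != 0 && targetShardCount%sourceShardCount != 0 {
		panic("shard counts must be multiples of each other")
	}
	zeroBased := sourceShardID - 1
	if targetShardCount <= sourceShardCount {
		// Scaling down: hash%target == (hash%source)%target when target divides source.
		return []int32{zeroBased%targetShardCount + 1}
	}
	// Scaling up: every target shard congruent to zeroBased modulo the source
	// count can own workflows that the source shard owns today.
	ratio := targetShardCount / sourceShardCount
	targetShardIDs := make([]int32, 0, ratio)
	for i := int32(0); i < ratio; i++ {
		targetShardIDs = append(targetShardIDs, zeroBased+i*sourceShardCount+1)
	}
	return targetShardIDs
}

func main() {
	fmt.Println(mapShardID(4, 16, 2)) // [2 6 10 14]: one stream per candidate target shard
	fmt.Println(mapShardID(16, 4, 6)) // [2]
}
```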