diff --git a/host/test_cluster.go b/host/test_cluster.go index 1b1544609b0..41a6d7cd7de 100644 --- a/host/test_cluster.go +++ b/host/test_cluster.go @@ -168,7 +168,10 @@ func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, er } } + clusterInfoMap := make(map[string]cluster.ClusterInformation) for clusterName, clusterInfo := range clusterMetadataConfig.ClusterInformation { + clusterInfo.ShardCount = options.HistoryConfig.NumHistoryShards + clusterInfoMap[clusterName] = clusterInfo _, err := testBase.ClusterMetadataManager.SaveClusterMetadata(context.Background(), &persistence.SaveClusterMetadataRequest{ ClusterMetadata: persistencespb.ClusterMetadata{ HistoryShardCount: options.HistoryConfig.NumHistoryShards, @@ -184,6 +187,7 @@ func NewCluster(options *TestClusterConfig, logger log.Logger) (*TestCluster, er return nil, err } } + clusterMetadataConfig.ClusterInformation = clusterInfoMap // This will save custom test search attributes to cluster metadata. // Actual Elasticsearch fields are created from index template (testdata/es_v7_index_template.json). 
diff --git a/service/history/replication/poller_manager.go b/service/history/replication/poller_manager.go index 0b01e6c6621..7e392077d2f 100644 --- a/service/history/replication/poller_manager.go +++ b/service/history/replication/poller_manager.go @@ -31,12 +31,18 @@ import ( ) type ( + pollerManager interface { + getSourceClusterShardIDs(sourceClusterName string) []int32 + } + pollerManagerImpl struct { currentShardId int32 clusterMetadata cluster.Metadata } ) +var _ pollerManager = (*pollerManagerImpl)(nil) + func newPollerManager( currentShardId int32, clusterMetadata cluster.Metadata, @@ -47,21 +53,21 @@ func newPollerManager( } } -func (p pollerManagerImpl) getPollingShardIDs(remoteClusterName string) []int32 { +func (p pollerManagerImpl) getSourceClusterShardIDs(sourceClusterName string) []int32 { currentCluster := p.clusterMetadata.GetCurrentClusterName() allClusters := p.clusterMetadata.GetAllClusterInfo() currentClusterInfo, ok := allClusters[currentCluster] if !ok { panic("Cannot get current cluster info from cluster metadata cache") } - remoteClusterInfo, ok := allClusters[remoteClusterName] + remoteClusterInfo, ok := allClusters[sourceClusterName] if !ok { - panic(fmt.Sprintf("Cannot get remote cluster %s info from cluster metadata cache", remoteClusterName)) + panic(fmt.Sprintf("Cannot get source cluster %s info from cluster metadata cache", sourceClusterName)) } - return generatePollingShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount) + return generateShardIDs(p.currentShardId, currentClusterInfo.ShardCount, remoteClusterInfo.ShardCount) } -func generatePollingShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 { +func generateShardIDs(localShardId int32, localShardCount int32, remoteShardCount int32) []int32 { var pollingShards []int32 if remoteShardCount <= localShardCount { if localShardId <= remoteShardCount { diff --git a/service/history/replication/poller_manager_test.go 
b/service/history/replication/poller_manager_test.go index 3b2be030b9d..abe69d9f79f 100644 --- a/service/history/replication/poller_manager_test.go +++ b/service/history/replication/poller_manager_test.go @@ -90,7 +90,7 @@ func TestGetPollingShardIds(t *testing.T) { t.Errorf("The code did not panic") } }() - shardIDs := generatePollingShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount) + shardIDs := generateShardIDs(tt.shardID, tt.localShardCount, tt.remoteShardCount) assert.Equal(t, tt.expectedShardIDs, shardIDs) }) } diff --git a/service/history/replication/task_processor.go b/service/history/replication/task_processor.go index 5558e5c1bc9..19a5ba2e0b2 100644 --- a/service/history/replication/task_processor.go +++ b/service/history/replication/task_processor.go @@ -74,9 +74,10 @@ type ( // taskProcessorImpl is responsible for processing replication tasks for a shard. taskProcessorImpl struct { - currentCluster string + status int32 + sourceCluster string - status int32 + sourceShardID int32 shard shard.Context historyEngine shard.Engine historySerializer serialization.Serializer @@ -109,6 +110,7 @@ type ( // NewTaskProcessor creates a new replication task processor. 
func NewTaskProcessor( + sourceShardID int32, shard shard.Context, historyEngine shard.Engine, config *configs.Config, @@ -132,9 +134,9 @@ func NewTaskProcessor( WithExpirationInterval(config.ReplicationTaskProcessorErrorRetryExpiration(shardID)) return &taskProcessorImpl{ - currentCluster: shard.GetClusterMetadata().GetCurrentClusterName(), - sourceCluster: replicationTaskFetcher.getSourceCluster(), status: common.DaemonStatusInitialized, + sourceShardID: sourceShardID, + sourceCluster: replicationTaskFetcher.getSourceCluster(), shard: shard, historyEngine: historyEngine, historySerializer: eventSerializer, @@ -383,6 +385,7 @@ func (p *taskProcessorImpl) convertTaskToDLQTask( switch replicationTask.TaskType { case enumsspb.REPLICATION_TASK_TYPE_SYNC_ACTIVITY_TASK: taskAttributes := replicationTask.GetSyncActivityTaskAttributes() + // TODO: GetShardID will break GetDLQReplicationMessages we need to handle DLQ for cross shard replication. return &persistence.PutReplicationTaskToDLQRequest{ ShardID: p.shard.GetShardID(), SourceClusterName: p.sourceCluster, @@ -414,6 +417,7 @@ func (p *taskProcessorImpl) convertTaskToDLQTask( // NOTE: last event vs next event, next event ID is exclusive nextEventID := lastEvent.GetEventId() + 1 + // TODO: GetShardID will break GetDLQReplicationMessages we need to handle DLQ for cross shard replication. return &persistence.PutReplicationTaskToDLQRequest{ ShardID: p.shard.GetShardID(), SourceClusterName: p.sourceCluster, @@ -442,6 +446,7 @@ func (p *taskProcessorImpl) convertTaskToDLQTask( return nil, err } + // TODO: GetShardID will break GetDLQReplicationMessages we need to handle DLQ for cross shard replication. 
return &persistence.PutReplicationTaskToDLQRequest{ ShardID: p.shard.GetShardID(), SourceClusterName: p.sourceCluster, @@ -464,7 +469,7 @@ func (p *taskProcessorImpl) paginationFn(_ []byte) ([]interface{}, []byte, error respChan := make(chan *replicationspb.ReplicationMessages, 1) p.requestChan <- &replicationTaskRequest{ token: &replicationspb.ReplicationToken{ - ShardId: p.shard.GetShardID(), + ShardId: p.sourceShardID, LastProcessedMessageId: p.maxRxProcessedTaskID, LastProcessedVisibilityTime: &p.maxRxProcessedTimestamp, LastRetrievedMessageId: p.maxRxReceivedTaskID, @@ -499,7 +504,7 @@ func (p *taskProcessorImpl) paginationFn(_ []byte) ([]interface{}, []byte, error if resp.GetHasMore() { p.rxTaskBackoff = time.Duration(0) } else { - p.rxTaskBackoff = p.config.ReplicationTaskProcessorNoTaskRetryWait(p.shard.GetShardID()) + p.rxTaskBackoff = p.config.ReplicationTaskProcessorNoTaskRetryWait(p.sourceShardID) } return tasks, nil, nil diff --git a/service/history/replication/task_processor_manager.go b/service/history/replication/task_processor_manager.go index 843924b7c8d..f20ca98b783 100644 --- a/service/history/replication/task_processor_manager.go +++ b/service/history/replication/task_processor_manager.go @@ -26,6 +26,7 @@ package replication import ( "context" + "fmt" "sync" "sync/atomic" "time" @@ -48,6 +49,10 @@ import ( "go.temporal.io/server/service/history/workflow" ) +const ( + clusterCallbackKey = "%s-%d" // key format: {cluster name}-{source shard ID} +) + type ( // taskProcessorManagerImpl is to manage replication task processors taskProcessorManagerImpl struct { @@ -61,6 +66,7 @@ type ( workflowCache workflow.Cache resender xdc.NDCHistoryResender taskExecutorProvider TaskExecutorProvider + taskPollerManager pollerManager metricsHandler metrics.MetricsHandler logger log.Logger @@ -109,6 +115,7 @@ func NewTaskProcessorManager( metricsHandler: shard.GetMetricsHandler(), taskProcessors: make(map[string]TaskProcessor), taskExecutorProvider: taskExecutorProvider, + taskPollerManager: 
newPollerManager(shard.GetShardID(), shard.GetClusterMetadata()), minTxAckedTaskID: persistence.EmptyQueueMessageID, shutdownChan: make(chan struct{}), } @@ -166,36 +173,39 @@ func (r *taskProcessorManagerImpl) handleClusterMetadataUpdate( if clusterName == currentClusterName { continue } - // The metadata triggers a update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address - // The callback covers three cases: - // Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata. - - if processor, ok := r.taskProcessors[clusterName]; ok { - // Case 1 and Case 3 - processor.Stop() - delete(r.taskProcessors, clusterName) - } - - if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled { - // Case 2 and Case 3 - fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName) - replicationTaskProcessor := NewTaskProcessor( - r.shard, - r.engine, - r.config, - r.shard.GetMetricsHandler(), - fetcher, - r.taskExecutorProvider(TaskExecutorParams{ - RemoteCluster: clusterName, - Shard: r.shard, - HistoryResender: r.resender, - DeleteManager: r.deleteMgr, - WorkflowCache: r.workflowCache, - }), - r.eventSerializer, - ) - replicationTaskProcessor.Start() - r.taskProcessors[clusterName] = replicationTaskProcessor + sourceShardIds := r.taskPollerManager.getSourceClusterShardIDs(clusterName) + for _, sourceShardId := range sourceShardIds { + perShardTaskProcessorKey := fmt.Sprintf(clusterCallbackKey, clusterName, sourceShardId) + // The metadata triggers an update when the following fields update: 1. Enabled 2. Initial Failover Version 3. Cluster address + // The callback covers three cases: + // Case 1: Remove a cluster Case 2: Add a new cluster Case 3: Refresh cluster metadata. 
+ if processor, ok := r.taskProcessors[perShardTaskProcessorKey]; ok { + // Case 1 and Case 3 + processor.Stop() + delete(r.taskProcessors, perShardTaskProcessorKey) + } + if clusterInfo := newClusterMetadata[clusterName]; clusterInfo != nil && clusterInfo.Enabled { + // Case 2 and Case 3 + fetcher := r.replicationTaskFetcherFactory.GetOrCreateFetcher(clusterName) + replicationTaskProcessor := NewTaskProcessor( + sourceShardId, + r.shard, + r.engine, + r.config, + r.shard.GetMetricsHandler(), + fetcher, + r.taskExecutorProvider(TaskExecutorParams{ + RemoteCluster: clusterName, + Shard: r.shard, + HistoryResender: r.resender, + DeleteManager: r.deleteMgr, + WorkflowCache: r.workflowCache, + }), + r.eventSerializer, + ) + replicationTaskProcessor.Start() + r.taskProcessors[perShardTaskProcessorKey] = replicationTaskProcessor + } } } } diff --git a/service/history/replication/task_processor_test.go b/service/history/replication/task_processor_test.go index 0298e428939..3c922899d66 100644 --- a/service/history/replication/task_processor_test.go +++ b/service/history/replication/task_processor_test.go @@ -148,6 +148,7 @@ func (s *taskProcessorSuite) SetupTest() { metricsClient := metrics.NoopMetricsHandler s.replicationTaskProcessor = NewTaskProcessor( + s.shardID, s.mockShard, s.mockEngine, s.config,