diff --git a/service/frontend/adminHandler.go b/service/frontend/adminHandler.go index 9f4951b2c10..f8606e8ce23 100644 --- a/service/frontend/adminHandler.go +++ b/service/frontend/adminHandler.go @@ -33,8 +33,10 @@ import ( "sync/atomic" "time" - "golang.org/x/sync/errgroup" + "google.golang.org/grpc/metadata" + "go.temporal.io/server/client/history" + "go.temporal.io/server/common/channel" "go.temporal.io/server/common/clock" "go.temporal.io/server/common/primitives" "go.temporal.io/server/common/util" @@ -1962,62 +1964,84 @@ func (adh *AdminHandler) getWorkflowCompletionEvent( } func (adh *AdminHandler) StreamWorkflowReplicationMessages( - targetCluster adminservice.AdminService_StreamWorkflowReplicationMessagesServer, + clientCluster adminservice.AdminService_StreamWorkflowReplicationMessagesServer, ) (retError error) { defer log.CapturePanic(adh.logger, &retError) - ctx := targetCluster.Context() - sourceCluster, err := adh.historyClient.StreamWorkflowReplicationMessages(ctx) + ctxMetadata, ok := metadata.FromIncomingContext(clientCluster.Context()) + if !ok { + return serviceerror.NewInvalidArgument("missing cluster & shard ID metadata") + } + _, serverClusterShardID, err := history.DecodeClusterShardMD(ctxMetadata) + if err != nil { + return err + } + + logger := log.With(adh.logger, tag.ShardID(serverClusterShardID.ShardID)) + logger.Info("AdminStreamReplicationMessages started.") + defer logger.Info("AdminStreamReplicationMessages stopped.") + + ctx := clientCluster.Context() + serverCluster, err := adh.historyClient.StreamWorkflowReplicationMessages(ctx) if err != nil { return err } - errGroup, ctx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - for ctx.Err() == nil { - req, err := targetCluster.Recv() + shutdownChan := channel.NewShutdownOnce() + go func() { + defer shutdownChan.Shutdown() + + for !shutdownChan.IsShutdown() { + req, err := clientCluster.Recv() if err != nil { - return err + logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(err)) + return } switch attr := req.GetAttributes().(type) { case *adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState: - if err = sourceCluster.Send(&historyservice.StreamWorkflowReplicationMessagesRequest{ + if err = serverCluster.Send(&historyservice.StreamWorkflowReplicationMessagesRequest{ Attributes: &historyservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{ SyncReplicationState: attr.SyncReplicationState, }, }); err != nil { - return err + logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(err)) + return } default: - return serviceerror.NewInternal(fmt.Sprintf( + logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(serviceerror.NewInternal(fmt.Sprintf( "StreamWorkflowReplicationMessages encountered unknown type: %T %v", attr, attr, - )) + )))) + return } } - return ctx.Err() - }) - errGroup.Go(func() error { - for ctx.Err() == nil { - resp, err := sourceCluster.Recv() + }() + go func() { + defer shutdownChan.Shutdown() + + for !shutdownChan.IsShutdown() { + resp, err := serverCluster.Recv() if err != nil { - return err + logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(err)) + return } switch attr := resp.GetAttributes().(type) { case *historyservice.StreamWorkflowReplicationMessagesResponse_Messages: - if err = targetCluster.Send(&adminservice.StreamWorkflowReplicationMessagesResponse{ + if err = clientCluster.Send(&adminservice.StreamWorkflowReplicationMessagesResponse{ Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{ Messages: attr.Messages, }, }); err != nil { - return err + logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(err)) + return } default: - return serviceerror.NewInternal(fmt.Sprintf( + logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(serviceerror.NewInternal(fmt.Sprintf( "StreamWorkflowReplicationMessages encountered unknown type: %T %v", attr, attr, - )) + )))) + return } } - return ctx.Err() - }) - return errGroup.Wait() + }() + <-shutdownChan.Channel() + return nil } diff --git a/service/frontend/adminHandler_test.go b/service/frontend/adminHandler_test.go index 347cc95dce4..68666055736 100644 --- a/service/frontend/adminHandler_test.go +++ b/service/frontend/adminHandler_test.go @@ -28,9 +28,14 @@ import ( "context" "errors" "fmt" + "math/rand" + "sync" "testing" "time" + "google.golang.org/grpc/metadata" + + historyclient "go.temporal.io/server/client/history" "go.temporal.io/server/common/clock" "go.temporal.io/server/common/persistence/visibility/store/standard/cassandra" "go.temporal.io/server/common/primitives" @@ -1445,3 +1450,93 @@ func (s *adminHandlerSuite) TestDeleteWorkflowExecution_CassandraVisibilityBacke _, err = s.handler.DeleteWorkflowExecution(context.Background(), request) s.NoError(err) } + +func (s *adminHandlerSuite) TestStreamWorkflowReplicationMessages_ClientToServerBroken() { + clientClusterShardID := historyclient.ClusterShardID{ + ClusterID: rand.Int31(), + ShardID: rand.Int31(), + } + serverClusterShardID := historyclient.ClusterShardID{ + ClusterID: rand.Int31(), + ShardID: rand.Int31(), + } + clusterShardMD := historyclient.EncodeClusterShardMD( + clientClusterShardID, + serverClusterShardID, + ) + ctx := metadata.NewIncomingContext(context.Background(), clusterShardMD) + clientCluster := adminservicemock.NewMockAdminService_StreamWorkflowReplicationMessagesServer(s.controller) + clientCluster.EXPECT().Context().Return(ctx).AnyTimes() + serverCluster := historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesClient(s.controller) + s.mockHistoryClient.EXPECT().StreamWorkflowReplicationMessages(ctx).Return(serverCluster, nil) + + waitGroupStart := sync.WaitGroup{} + waitGroupStart.Add(2) + waitGroupEnd := sync.WaitGroup{} + waitGroupEnd.Add(2) + channel := make(chan struct{}) + + clientCluster.EXPECT().Recv().DoAndReturn(func() (*adminservice.StreamWorkflowReplicationMessagesRequest, error) { + waitGroupStart.Done() + waitGroupStart.Wait() + + defer waitGroupEnd.Done() + return nil, serviceerror.NewUnavailable("random error") + }) + serverCluster.EXPECT().Recv().DoAndReturn(func() (*historyservice.StreamWorkflowReplicationMessagesResponse, error) { + waitGroupStart.Done() + waitGroupStart.Wait() + + defer waitGroupEnd.Done() + <-channel + return nil, serviceerror.NewUnavailable("random error") + }) + _ = s.handler.StreamWorkflowReplicationMessages(clientCluster) + close(channel) + waitGroupEnd.Wait() +} + +func (s *adminHandlerSuite) TestStreamWorkflowReplicationMessages_ServerToClientBroken() { + clientClusterShardID := historyclient.ClusterShardID{ + ClusterID: rand.Int31(), + ShardID: rand.Int31(), + } + serverClusterShardID := historyclient.ClusterShardID{ + ClusterID: rand.Int31(), + ShardID: rand.Int31(), + } + clusterShardMD := historyclient.EncodeClusterShardMD( + clientClusterShardID, + serverClusterShardID, + ) + ctx := metadata.NewIncomingContext(context.Background(), clusterShardMD) + clientCluster := adminservicemock.NewMockAdminService_StreamWorkflowReplicationMessagesServer(s.controller) + clientCluster.EXPECT().Context().Return(ctx).AnyTimes() + serverCluster := historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesClient(s.controller) + s.mockHistoryClient.EXPECT().StreamWorkflowReplicationMessages(ctx).Return(serverCluster, nil) + + waitGroupStart := sync.WaitGroup{} + waitGroupStart.Add(2) + waitGroupEnd := sync.WaitGroup{} + waitGroupEnd.Add(2) + channel := make(chan struct{}) + + clientCluster.EXPECT().Recv().DoAndReturn(func() (*adminservice.StreamWorkflowReplicationMessagesRequest, error) { + waitGroupStart.Done() + waitGroupStart.Wait() + + defer waitGroupEnd.Done() + <-channel + return nil, serviceerror.NewUnavailable("random error") + }) + serverCluster.EXPECT().Recv().DoAndReturn(func() (*historyservice.StreamWorkflowReplicationMessagesResponse, error) { + waitGroupStart.Done() + waitGroupStart.Wait() + + defer waitGroupEnd.Done() + return nil, serviceerror.NewUnavailable("random error") + }) + _ = s.handler.StreamWorkflowReplicationMessages(clientCluster) + close(channel) + waitGroupEnd.Wait() +} diff --git a/service/history/replication/stream_receiver_monitor.go b/service/history/replication/stream_receiver_monitor.go index 017b9d6edbb..9883d601265 100644 --- a/service/history/replication/stream_receiver_monitor.go +++ b/service/history/replication/stream_receiver_monitor.go @@ -36,7 +36,7 @@ import ( ) const ( - streamReceiverMonitorInterval = 5 * time.Second + streamReceiverMonitorInterval = 2 * time.Second ) type (