Better handle admin handler stream replication API lifecycle (#4647)
Make sure the admin handler stream replication API returns when either the client -> server or the server -> client link breaks.
wxing1292 authored and dnr committed Jul 21, 2023
1 parent 7a125d9 commit 32b4219
Showing 3 changed files with 146 additions and 27 deletions.
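
A note on the failure mode this commit addresses: the previous handler ran both stream pumps inside an errgroup and ended with errGroup.Wait(), which only returns once *both* goroutines exit. When one link broke, the errgroup canceled its derived context, but a Recv() blocked on the other, still-healthy stream does not watch that derived context, so the second goroutine never woke up and the handler could hang indefinitely. A minimal standalone repro of that shape (a sketch, not Temporal code; the bare channel stands in for a blocking gRPC Recv):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

func main() {
	g, ctx := errgroup.WithContext(context.Background())
	blocked := make(chan struct{}) // stands in for a gRPC Recv that ignores ctx

	g.Go(func() error {
		return errors.New("client -> server link broke") // cancels the derived ctx
	})
	g.Go(func() error {
		<-blocked // Recv is not bound to the errgroup's ctx; stays blocked
		return ctx.Err()
	})

	done := make(chan error, 1)
	go func() { done <- g.Wait() }()
	select {
	case err := <-done:
		fmt.Println("handler returned:", err)
	case <-time.After(time.Second):
		fmt.Println("handler stuck: Wait() still blocked after 1s")
		close(blocked) // release the blocked goroutine for a clean exit
	}
}
```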
76 changes: 50 additions & 26 deletions service/frontend/adminHandler.go
@@ -33,8 +33,10 @@ import (
 	"sync/atomic"
 	"time"
 
-	"golang.org/x/sync/errgroup"
+	"google.golang.org/grpc/metadata"
 
+	"go.temporal.io/server/client/history"
+	"go.temporal.io/server/common/channel"
 	"go.temporal.io/server/common/clock"
 	"go.temporal.io/server/common/primitives"
 	"go.temporal.io/server/common/util"
@@ -1962,62 +1964,84 @@ func (adh *AdminHandler) getWorkflowCompletionEvent(
 }
 
 func (adh *AdminHandler) StreamWorkflowReplicationMessages(
-	targetCluster adminservice.AdminService_StreamWorkflowReplicationMessagesServer,
+	clientCluster adminservice.AdminService_StreamWorkflowReplicationMessagesServer,
 ) (retError error) {
 	defer log.CapturePanic(adh.logger, &retError)
 
-	ctx := targetCluster.Context()
-	sourceCluster, err := adh.historyClient.StreamWorkflowReplicationMessages(ctx)
+	ctxMetadata, ok := metadata.FromIncomingContext(clientCluster.Context())
+	if !ok {
+		return serviceerror.NewInvalidArgument("missing cluster & shard ID metadata")
+	}
+	_, serverClusterShardID, err := history.DecodeClusterShardMD(ctxMetadata)
+	if err != nil {
+		return err
+	}
+
+	logger := log.With(adh.logger, tag.ShardID(serverClusterShardID.ShardID))
+	logger.Info("AdminStreamReplicationMessages started.")
+	defer logger.Info("AdminStreamReplicationMessages stopped.")
+
+	ctx := clientCluster.Context()
+	serverCluster, err := adh.historyClient.StreamWorkflowReplicationMessages(ctx)
 	if err != nil {
 		return err
 	}
 
-	errGroup, ctx := errgroup.WithContext(ctx)
-	errGroup.Go(func() error {
-		for ctx.Err() == nil {
-			req, err := targetCluster.Recv()
+	shutdownChan := channel.NewShutdownOnce()
+	go func() {
+		defer shutdownChan.Shutdown()
+
+		for !shutdownChan.IsShutdown() {
+			req, err := clientCluster.Recv()
 			if err != nil {
-				return err
+				logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(err))
+				return
 			}
 			switch attr := req.GetAttributes().(type) {
 			case *adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState:
-				if err = sourceCluster.Send(&historyservice.StreamWorkflowReplicationMessagesRequest{
+				if err = serverCluster.Send(&historyservice.StreamWorkflowReplicationMessagesRequest{
 					Attributes: &historyservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{
 						SyncReplicationState: attr.SyncReplicationState,
 					},
 				}); err != nil {
-					return err
+					logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(err))
+					return
 				}
 			default:
-				return serviceerror.NewInternal(fmt.Sprintf(
+				logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(serviceerror.NewInternal(fmt.Sprintf(
 					"StreamWorkflowReplicationMessages encountered unknown type: %T %v", attr, attr,
-				))
+				))))
+				return
 			}
 		}
-		return ctx.Err()
-	})
-	errGroup.Go(func() error {
-		for ctx.Err() == nil {
-			resp, err := sourceCluster.Recv()
+	}()
+	go func() {
+		defer shutdownChan.Shutdown()
+
+		for !shutdownChan.IsShutdown() {
+			resp, err := serverCluster.Recv()
 			if err != nil {
-				return err
+				logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(err))
+				return
 			}
 			switch attr := resp.GetAttributes().(type) {
 			case *historyservice.StreamWorkflowReplicationMessagesResponse_Messages:
-				if err = targetCluster.Send(&adminservice.StreamWorkflowReplicationMessagesResponse{
+				if err = clientCluster.Send(&adminservice.StreamWorkflowReplicationMessagesResponse{
 					Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{
 						Messages: attr.Messages,
 					},
 				}); err != nil {
-					return err
+					logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(err))
+					return
 				}
 			default:
-				return serviceerror.NewInternal(fmt.Sprintf(
+				logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(serviceerror.NewInternal(fmt.Sprintf(
 					"StreamWorkflowReplicationMessages encountered unknown type: %T %v", attr, attr,
-				))
+				))))
+				return
 			}
 		}
-		return ctx.Err()
-	})
-	return errGroup.Wait()
+	}()
+	<-shutdownChan.Channel()
+	return nil
 }
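
The essence of the new lifecycle above: instead of waiting for both pumps, the handler waits only for the first one to stop, via a trip-once channel, and leaves stream teardown to gRPC when it returns. A distilled sketch of that waiting pattern, assuming stand-in pump functions and a local shutdownOnce that mimics the small surface of go.temporal.io/server/common/channel.ShutdownOnce used here:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

// shutdownOnce mimics channel.ShutdownOnce: trip at most once,
// observable as a closed channel, safe from multiple goroutines.
type shutdownOnce struct {
	once sync.Once
	ch   chan struct{}
}

func newShutdownOnce() *shutdownOnce { return &shutdownOnce{ch: make(chan struct{})} }

func (s *shutdownOnce) Shutdown()                { s.once.Do(func() { close(s.ch) }) }
func (s *shutdownOnce) Channel() <-chan struct{} { return s.ch }

func handler(clientToServer, serverToClient func() error) error {
	shutdown := newShutdownOnce()
	go func() { // client -> server pump
		defer shutdown.Shutdown()
		if err := clientToServer(); err != nil {
			fmt.Println("client -> server:", err)
		}
	}()
	go func() { // server -> client pump
		defer shutdown.Shutdown()
		if err := serverToClient(); err != nil {
			fmt.Println("server -> client:", err)
		}
	}()
	<-shutdown.Channel() // first pump to stop ends the handler
	return nil
}

func main() {
	broken := func() error { return errors.New("link broke") }
	healthy := func() error { time.Sleep(time.Hour); return nil } // stays blocked
	_ = handler(broken, healthy)
	fmt.Println("handler returned without waiting for the healthy pump")
}
```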
95 changes: 95 additions & 0 deletions service/frontend/adminHandler_test.go
@@ -28,9 +28,14 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"math/rand"
+	"sync"
 	"testing"
 	"time"
 
+	"google.golang.org/grpc/metadata"
+
+	historyclient "go.temporal.io/server/client/history"
 	"go.temporal.io/server/common/clock"
 	"go.temporal.io/server/common/persistence/visibility/store/standard/cassandra"
 	"go.temporal.io/server/common/primitives"
@@ -1445,3 +1450,93 @@ func (s *adminHandlerSuite) TestDeleteWorkflowExecution_CassandraVisibilityBackend(
 	_, err = s.handler.DeleteWorkflowExecution(context.Background(), request)
 	s.NoError(err)
 }
+
+func (s *adminHandlerSuite) TestStreamWorkflowReplicationMessages_ClientToServerBroken() {
+	clientClusterShardID := historyclient.ClusterShardID{
+		ClusterID: rand.Int31(),
+		ShardID:   rand.Int31(),
+	}
+	serverClusterShardID := historyclient.ClusterShardID{
+		ClusterID: rand.Int31(),
+		ShardID:   rand.Int31(),
+	}
+	clusterShardMD := historyclient.EncodeClusterShardMD(
+		clientClusterShardID,
+		serverClusterShardID,
+	)
+	ctx := metadata.NewIncomingContext(context.Background(), clusterShardMD)
+	clientCluster := adminservicemock.NewMockAdminService_StreamWorkflowReplicationMessagesServer(s.controller)
+	clientCluster.EXPECT().Context().Return(ctx).AnyTimes()
+	serverCluster := historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesClient(s.controller)
+	s.mockHistoryClient.EXPECT().StreamWorkflowReplicationMessages(ctx).Return(serverCluster, nil)
+
+	waitGroupStart := sync.WaitGroup{}
+	waitGroupStart.Add(2)
+	waitGroupEnd := sync.WaitGroup{}
+	waitGroupEnd.Add(2)
+	channel := make(chan struct{})
+
+	clientCluster.EXPECT().Recv().DoAndReturn(func() (*adminservice.StreamWorkflowReplicationMessagesRequest, error) {
+		waitGroupStart.Done()
+		waitGroupStart.Wait()
+
+		defer waitGroupEnd.Done()
+		return nil, serviceerror.NewUnavailable("random error")
+	})
+	serverCluster.EXPECT().Recv().DoAndReturn(func() (*historyservice.StreamWorkflowReplicationMessagesResponse, error) {
+		waitGroupStart.Done()
+		waitGroupStart.Wait()
+
+		defer waitGroupEnd.Done()
+		<-channel
+		return nil, serviceerror.NewUnavailable("random error")
+	})
+	_ = s.handler.StreamWorkflowReplicationMessages(clientCluster)
+	close(channel)
+	waitGroupEnd.Wait()
+}
+
+func (s *adminHandlerSuite) TestStreamWorkflowReplicationMessages_ServerToClientBroken() {
+	clientClusterShardID := historyclient.ClusterShardID{
+		ClusterID: rand.Int31(),
+		ShardID:   rand.Int31(),
+	}
+	serverClusterShardID := historyclient.ClusterShardID{
+		ClusterID: rand.Int31(),
+		ShardID:   rand.Int31(),
+	}
+	clusterShardMD := historyclient.EncodeClusterShardMD(
+		clientClusterShardID,
+		serverClusterShardID,
+	)
+	ctx := metadata.NewIncomingContext(context.Background(), clusterShardMD)
+	clientCluster := adminservicemock.NewMockAdminService_StreamWorkflowReplicationMessagesServer(s.controller)
+	clientCluster.EXPECT().Context().Return(ctx).AnyTimes()
+	serverCluster := historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesClient(s.controller)
+	s.mockHistoryClient.EXPECT().StreamWorkflowReplicationMessages(ctx).Return(serverCluster, nil)
+
+	waitGroupStart := sync.WaitGroup{}
+	waitGroupStart.Add(2)
+	waitGroupEnd := sync.WaitGroup{}
+	waitGroupEnd.Add(2)
+	channel := make(chan struct{})
+
+	clientCluster.EXPECT().Recv().DoAndReturn(func() (*adminservice.StreamWorkflowReplicationMessagesRequest, error) {
+		waitGroupStart.Done()
+		waitGroupStart.Wait()
+
+		defer waitGroupEnd.Done()
+		<-channel
+		return nil, serviceerror.NewUnavailable("random error")
+	})
+	serverCluster.EXPECT().Recv().DoAndReturn(func() (*historyservice.StreamWorkflowReplicationMessagesResponse, error) {
+		waitGroupStart.Done()
+		waitGroupStart.Wait()
+
+		defer waitGroupEnd.Done()
+		return nil, serviceerror.NewUnavailable("random error")
+	})
+	_ = s.handler.StreamWorkflowReplicationMessages(clientCluster)
+	close(channel)
+	waitGroupEnd.Wait()
+}
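
Both tests above synchronize their mocks the same way: waitGroupStart is a two-party barrier guaranteeing both mocked Recv calls are in flight before either returns, the unbuffered channel pins the healthy direction open until the handler has already returned (proving one broken direction suffices), and waitGroupEnd confirms both mocks ran to completion. A distilled sketch of that barrier, with illustrative names rather than the suite's:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

func main() {
	var start, end sync.WaitGroup
	start.Add(2)
	end.Add(2)
	release := make(chan struct{}) // holds the healthy direction open

	go func() { // broken direction: fails once both sides are running
		defer end.Done()
		start.Done()
		start.Wait() // barrier: both mocked Recv calls are in flight
		fmt.Println("recv:", errors.New("simulated link break"))
	}()
	go func() { // healthy direction: blocked until the test releases it
		defer end.Done()
		start.Done()
		start.Wait()
		<-release
		fmt.Println("recv: released after the handler returned")
	}()

	// In the real tests the handler returns here, showing one broken
	// direction is enough; only then is the healthy mock released.
	close(release)
	end.Wait()
}
```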
2 changes: 1 addition & 1 deletion service/history/replication/stream_receiver_monitor.go
@@ -36,7 +36,7 @@ import (
 )
 
 const (
-	streamReceiverMonitorInterval = 5 * time.Second
+	streamReceiverMonitorInterval = 2 * time.Second
 )
 
 type (
