diff --git a/service/history/workflow/update/registry.go b/service/history/workflow/update/registry.go index 1f9e315191f..70f649db65e 100644 --- a/service/history/workflow/update/registry.go +++ b/service/history/workflow/update/registry.go @@ -36,6 +36,7 @@ import ( protocolpb "go.temporal.io/api/protocol/v1" "go.temporal.io/api/serviceerror" updatepb "go.temporal.io/api/update/v1" + updatespb "go.temporal.io/server/api/update/v1" "go.temporal.io/server/common/future" "go.temporal.io/server/common/log" @@ -208,7 +209,7 @@ func (r *RegistryImpl) ReadOutgoingMessages( } // TODO (alex-update): currently sequencing_id is simply pointing to the - // event before WorkflowTaskStartedEvent. SDKs are supposed to respect this + // event before WorkflowTaskStartedEvent. SDKs are supposed to respect this // and process messages (specifically, updates) after event with that ID. // In the future, sequencing_id could point to some specific event // (specifically, signal) after which the update should be processed. diff --git a/tests/integration_test.go b/tests/integration_test.go index f4e4043f9cf..314a13f356c 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -37,6 +37,8 @@ import ( commonpb "go.temporal.io/api/common/v1" "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/server/api/adminservice/v1" + "go.temporal.io/server/common" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/payloads" ) @@ -85,6 +87,20 @@ func (s *integrationSuite) sendSignal(namespace string, execution *commonpb.Work return err } +func (s *integrationSuite) closeShard(wid string) { + s.T().Helper() + + resp, err := s.engine.DescribeNamespace(NewContext(), &workflowservice.DescribeNamespaceRequest{ + Namespace: s.namespace, + }) + s.NoError(err) + + _, err = s.adminClient.CloseShard(NewContext(), &adminservice.CloseShardRequest{ + ShardId: common.WorkflowIDToHistoryShard(resp.NamespaceInfo.Id, wid, s.testClusterConfig.HistoryConfig.NumHistoryShards), + }) + s.NoError(err) +} + func unmarshalAny[T proto.Message](s *integrationSuite, a *types.Any) T { s.T().Helper() pb := new(T) diff --git a/tests/testvars/test_vars.go b/tests/testvars/test_vars.go index c7b01310ee5..99affad35c9 100644 --- a/tests/testvars/test_vars.go +++ b/tests/testvars/test_vars.go @@ -191,6 +191,14 @@ func (tv *TestVars) WithHandlerName(handlerName string, key ...string) *TestVars return tv.cloneSet("handler_name", handlerName, key) } +func (tv *TestVars) WorkerIdentity(key ...string) string { + return tv.getOrCreate("worker_identity", key) +} + +func (tv *TestVars) WithWorkerIdentity(identity string, key ...string) *TestVars { + return tv.cloneSet("worker_identity", identity, key) +} + // ----------- Generic methods ------------ func (tv *TestVars) InfiniteTimeout() *time.Duration { diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go index f61c1f87b3e..52030c3e15a 100644 --- a/tests/update_workflow_test.go +++ b/tests/update_workflow_test.go @@ -2985,3 +2985,287 @@ func (s *integrationSuite) TestUpdateWorkflow_SpeculativeWorkflowTask_Heartbeat( 14 WorkflowTaskCompleted 15 WorkflowExecutionCompleted`, events) } + +func (s *integrationSuite) TestUpdateWorkflow_NewScheduledSpeculativeWorkflowTaskLost_BecauseOfShardMove() { + tv := testvars.New(s.T().Name()) + + tv = s.startWorkflow(tv) + + wtHandlerCalls := 0 + wtHandler := func(execution *commonpb.WorkflowExecution, wt *commonpb.WorkflowType, previousStartedEventID, startedEventID int64, history *historypb.History) ([]*commandpb.Command, error) { + wtHandlerCalls++ + switch wtHandlerCalls { + case 1: + // Completes first WT with update unrelated command. + return []*commandpb.Command{{ + CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, + Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{ + ActivityId: tv.ActivityID(), + ActivityType: tv.ActivityType(), + TaskQueue: tv.TaskQueue(), + ScheduleToCloseTimeout: tv.InfiniteTimeout(), + }}, + }}, nil + case 2: + s.EqualHistory(` + 1 WorkflowExecutionStarted + 2 WorkflowTaskScheduled + 3 WorkflowTaskStarted + 4 WorkflowTaskCompleted + 5 ActivityTaskScheduled + 6 WorkflowExecutionSignaled + 7 WorkflowTaskScheduled + 8 WorkflowTaskStarted`, history) + return []*commandpb.Command{{ + CommandType: enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION, + Attributes: &commandpb.Command_CompleteWorkflowExecutionCommandAttributes{CompleteWorkflowExecutionCommandAttributes: &commandpb.CompleteWorkflowExecutionCommandAttributes{}}, + }}, nil + default: + s.Failf("wtHandler called too many times", "wtHandler shouldn't be called %d times", wtHandlerCalls) + return nil, nil + } + } + + msgHandlerCalls := 0 + msgHandler := func(task *workflowservice.PollWorkflowTaskQueueResponse) ([]*protocolpb.Message, error) { + msgHandlerCalls++ + switch msgHandlerCalls { + case 1: + return nil, nil + case 2: + s.Empty(task.Messages, "update must be lost due to shard reload") + return nil, nil + default: + s.Failf("msgHandler called too many times", "msgHandler shouldn't be called %d times", msgHandlerCalls) + return nil, nil + } + } + + poller := &TaskPoller{ + Engine: s.engine, + Namespace: s.namespace, + TaskQueue: tv.TaskQueue(), + WorkflowTaskHandler: wtHandler, + MessageHandler: msgHandler, + Logger: s.Logger, + T: s.T(), + } + + // Drain exiting first workflow task. + _, err := poller.PollAndProcessWorkflowTask(true, false) + s.NoError(err) + + updateResultCh := make(chan struct{}) + updateWorkflowFn := func() { + halfSecondTimeoutCtx, cancel := context.WithTimeout(NewContext(), 500*time.Millisecond) + defer cancel() + + updateResponse, err1 := s.engine.UpdateWorkflowExecution(halfSecondTimeoutCtx, &workflowservice.UpdateWorkflowExecutionRequest{ + Namespace: s.namespace, + WorkflowExecution: tv.WorkflowExecution(), + Request: &updatepb.Request{ + Meta: &updatepb.Meta{UpdateId: tv.UpdateID("1")}, + Input: &updatepb.Input{ + Name: tv.Any(), + Args: payloads.EncodeString(tv.Any()), + }, + }, + }) + assert.Error(s.T(), err1) + assert.True(s.T(), common.IsContextDeadlineExceededErr(err1), err1) + assert.Nil(s.T(), updateResponse) + + updateResultCh <- struct{}{} + } + go updateWorkflowFn() + + // Close shard, Speculative WT with update will be lost. + s.closeShard(tv.WorkflowID()) + + pollCtx, cancel := context.WithTimeout(NewContext(), common.MinLongPollTimeout+100*time.Millisecond) + defer cancel() + pollResponse, err := s.engine.PollWorkflowTaskQueue(pollCtx, &workflowservice.PollWorkflowTaskQueueRequest{ + Namespace: s.namespace, + TaskQueue: tv.TaskQueue(), + Identity: tv.WorkerIdentity(), + }) + s.NoError(err) + s.Nil(pollResponse.Messages) + + <-updateResultCh + + // Send signal to schedule new WT. + err = s.sendSignal(s.namespace, tv.WorkflowExecution(), tv.Any(), payloads.EncodeString(tv.Any()), tv.Any()) + s.NoError(err) + + // Complete workflow. + completeWorkflowResp, err := poller.PollAndProcessWorkflowTask(false, false) + s.NoError(err) + s.NotNil(completeWorkflowResp) + + s.Equal(2, wtHandlerCalls) + s.Equal(2, msgHandlerCalls) + + events := s.getHistory(s.namespace, tv.WorkflowExecution()) + + s.EqualHistoryEvents(` + 1 WorkflowExecutionStarted + 2 WorkflowTaskScheduled + 3 WorkflowTaskStarted + 4 WorkflowTaskCompleted + 5 ActivityTaskScheduled + 6 WorkflowExecutionSignaled + 7 WorkflowTaskScheduled + 8 WorkflowTaskStarted + 9 WorkflowTaskCompleted + 10 WorkflowExecutionCompleted`, events) +} + +func (s *integrationSuite) TestUpdateWorkflow_NewStartedSpeculativeWorkflowTaskLost_BecauseOfShardMove() { + tv := testvars.New(s.T().Name()) + + tv = s.startWorkflow(tv) + + wtHandlerCalls := 0 + wtHandler := func(execution *commonpb.WorkflowExecution, wt *commonpb.WorkflowType, previousStartedEventID, startedEventID int64, history *historypb.History) ([]*commandpb.Command, error) { + wtHandlerCalls++ + switch wtHandlerCalls { + case 1: + // Completes first WT with update unrelated command. + return []*commandpb.Command{{ + CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, + Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{ + ActivityId: tv.ActivityID(), + ActivityType: tv.ActivityType(), + TaskQueue: tv.TaskQueue(), + ScheduleToCloseTimeout: tv.InfiniteTimeout(), + }}, + }}, nil + case 2: + // Speculative WT, with update.Request message. + s.EqualHistory(` + 1 WorkflowExecutionStarted + 2 WorkflowTaskScheduled + 3 WorkflowTaskStarted + 4 WorkflowTaskCompleted + 5 ActivityTaskScheduled + 6 WorkflowTaskScheduled + 7 WorkflowTaskStarted`, history) + + // Close shard, Speculative WT (6 and 7) will be lost, and NotFound error will be returned to RespondWorkflowTaskCompleted. + s.closeShard(tv.WorkflowID()) + + return s.acceptCompleteUpdateCommands(tv, "1"), nil + case 3: + s.EqualHistory(` + 1 WorkflowExecutionStarted + 2 WorkflowTaskScheduled + 3 WorkflowTaskStarted + 4 WorkflowTaskCompleted + 5 ActivityTaskScheduled + 6 WorkflowExecutionSignaled + 7 WorkflowTaskScheduled + 8 WorkflowTaskStarted`, history) + return []*commandpb.Command{{ + CommandType: enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION, + Attributes: &commandpb.Command_CompleteWorkflowExecutionCommandAttributes{CompleteWorkflowExecutionCommandAttributes: &commandpb.CompleteWorkflowExecutionCommandAttributes{}}, + }}, nil + default: + s.Failf("wtHandler called too many times", "wtHandler shouldn't be called %d times", wtHandlerCalls) + return nil, nil + } + } + + msgHandlerCalls := 0 + msgHandler := func(task *workflowservice.PollWorkflowTaskQueueResponse) ([]*protocolpb.Message, error) { + msgHandlerCalls++ + switch msgHandlerCalls { + case 1: + return nil, nil + case 2: + updRequestMsg := task.Messages[0] + s.EqualValues(6, updRequestMsg.GetEventId()) + + return s.acceptCompleteUpdateMessages(tv, updRequestMsg, "1"), nil + case 3: + s.Empty(task.Messages, "update must be lost due to shard reload") + return nil, nil + default: + s.Failf("msgHandler called too many times", "msgHandler shouldn't be called %d times", msgHandlerCalls) + return nil, nil + } + } + + poller := &TaskPoller{ + Engine: s.engine, + Namespace: s.namespace, + TaskQueue: tv.TaskQueue(), + WorkflowTaskHandler: wtHandler, + MessageHandler: msgHandler, + Logger: s.Logger, + T: s.T(), + } + + // Drain exiting first workflow task. + _, err := poller.PollAndProcessWorkflowTask(true, false) + s.NoError(err) + + updateResultCh := make(chan struct{}) + updateWorkflowFn := func() { + halfSecondTimeoutCtx, cancel := context.WithTimeout(NewContext(), 500*time.Millisecond) + defer cancel() + + updateResponse, err1 := s.engine.UpdateWorkflowExecution(halfSecondTimeoutCtx, &workflowservice.UpdateWorkflowExecutionRequest{ + Namespace: s.namespace, + WorkflowExecution: tv.WorkflowExecution(), + Request: &updatepb.Request{ + Meta: &updatepb.Meta{UpdateId: tv.UpdateID("1")}, + Input: &updatepb.Input{ + Name: tv.Any(), + Args: payloads.EncodeString(tv.Any()), + }, + }, + }) + assert.Error(s.T(), err1) + assert.True(s.T(), common.IsContextDeadlineExceededErr(err1), err1) + assert.Nil(s.T(), updateResponse) + + updateResultCh <- struct{}{} + } + go updateWorkflowFn() + + // Process update in workflow. + _, updateResp, err := poller.PollAndProcessWorkflowTaskWithAttemptAndRetryAndForceNewWorkflowTask(false, false, false, false, 1, 1, true, nil) + s.Error(err) + s.IsType(&serviceerror.NotFound{}, err) + s.ErrorContains(err, "Workflow task not found") + s.Nil(updateResp) + + <-updateResultCh + + // Send signal to schedule new WT. + err = s.sendSignal(s.namespace, tv.WorkflowExecution(), tv.Any(), payloads.EncodeString(tv.Any()), tv.Any()) + s.NoError(err) + + // Complete workflow. + completeWorkflowResp, err := poller.PollAndProcessWorkflowTask(false, false) + s.NoError(err) + s.NotNil(completeWorkflowResp) + + s.Equal(3, wtHandlerCalls) + s.Equal(3, msgHandlerCalls) + + events := s.getHistory(s.namespace, tv.WorkflowExecution()) + + s.EqualHistoryEvents(` + 1 WorkflowExecutionStarted + 2 WorkflowTaskScheduled + 3 WorkflowTaskStarted + 4 WorkflowTaskCompleted + 5 ActivityTaskScheduled + 6 WorkflowExecutionSignaled + 7 WorkflowTaskScheduled + 8 WorkflowTaskStarted + 9 WorkflowTaskCompleted + 10 WorkflowExecutionCompleted`, events) +}