From 3f0e5721e4b0470a77387df6b417af3bd4451b6f Mon Sep 17 00:00:00 2001 From: Alex Shtin Date: Fri, 26 May 2023 12:31:00 -0700 Subject: [PATCH 1/2] Allow workflow to complete if there are updates in the registry --- service/history/workflow/update/registry.go | 23 ++-- service/history/workflowTaskHandler.go | 19 +-- tests/update_workflow_test.go | 139 +++++++++++++++++++- 3 files changed, 152 insertions(+), 29 deletions(-) diff --git a/service/history/workflow/update/registry.go b/service/history/workflow/update/registry.go index 85afea3beb9..44239761e54 100644 --- a/service/history/workflow/update/registry.go +++ b/service/history/workflow/update/registry.go @@ -35,6 +35,7 @@ import ( protocolpb "go.temporal.io/api/protocol/v1" "go.temporal.io/api/serviceerror" updatepb "go.temporal.io/api/update/v1" + persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/future" "go.temporal.io/server/common/log" @@ -55,14 +56,13 @@ type ( // new update if no update is found. Find(ctx context.Context, protocolInstanceID string) (*Update, bool) - // ReadOutoundMessages polls each registered Update for outbound + // ReadOutgoingMessages polls each registered Update for outbound // messages and returns them. ReadOutgoingMessages(startedEventID int64) ([]*protocolpb.Message, error) - // HasUndeliveredUpdates returns true if the registry has Updates that - // are not known to have been seen by user workflow code. In practice - // this means updates that have not yet been accepted or rejected. - HasUndeliveredUpdates() bool + // TerminateUpdates terminates all existing updates in the registry + // and notifies update API callers with the corresponding error. + TerminateUpdates(ctx context.Context, eventStore EventStore) // HasOutgoing returns true if the registry has any Updates that want to // sent messages to a worker. 
@@ -167,15 +167,10 @@ func (r *RegistryImpl) Find(ctx context.Context, id string) (*Update, bool) { return r.findLocked(ctx, id) } -func (r *RegistryImpl) HasUndeliveredUpdates() bool { - r.mu.RLock() - defer r.mu.RUnlock() - for _, upd := range r.updates { - if !upd.hasBeenSeenByWorkflowExecution() { - return true - } - } - return false +func (r *RegistryImpl) TerminateUpdates(_ context.Context, _ EventStore) { + // TODO (alex-update): implement + // This method is not implemented yet, and update API callers will just time out. + // In the future, it should remove all existing updates and notify callers with a better error. } func (r *RegistryImpl) HasOutgoing() bool { diff --git a/service/history/workflowTaskHandler.go b/service/history/workflowTaskHandler.go index 90684d2b0e9..1fa4fa677a6 100644 --- a/service/history/workflowTaskHandler.go +++ b/service/history/workflowTaskHandler.go @@ -37,6 +37,7 @@ import ( protocolpb "go.temporal.io/api/protocol/v1" "go.temporal.io/api/serviceerror" "go.temporal.io/api/workflowservice/v1" + "go.temporal.io/server/internal/effect" "go.temporal.io/server/internal/protocol" "go.temporal.io/server/service/history/workflow/update" @@ -568,9 +569,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCompleteWorkflow( return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil) } - if handler.updateRegistry.HasUndeliveredUpdates() { - return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil) - } + handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState)) if err := handler.validateCommandAttr( func() (enumspb.WorkflowTaskFailedCause, error) { @@ -630,9 +629,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandFailWorkflow( return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil) } - if handler.updateRegistry.HasUndeliveredUpdates() { - return 
handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil) - } + handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState)) if err := handler.validateCommandAttr( func() (enumspb.WorkflowTaskFailedCause, error) { @@ -726,7 +723,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCancelTimer( } func (handler *workflowTaskHandlerImpl) handleCommandCancelWorkflow( - _ context.Context, + ctx context.Context, attr *commandpb.CancelWorkflowExecutionCommandAttributes, ) error { @@ -736,9 +733,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCancelWorkflow( return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil) } - if handler.updateRegistry.HasUndeliveredUpdates() { - return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil) - } + handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState)) if err := handler.validateCommandAttr( func() (enumspb.WorkflowTaskFailedCause, error) { @@ -843,9 +838,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandContinueAsNewWorkflow( return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil) } - if handler.updateRegistry.HasUndeliveredUpdates() { - return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil) - } + handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState)) namespaceName := handler.mutableState.GetNamespaceEntry().Name() diff --git a/tests/update_workflow_test.go b/tests/update_workflow_test.go index 629beef0346..5c1f43af027 100644 --- a/tests/update_workflow_test.go +++ b/tests/update_workflow_test.go @@ -1781,7 +1781,7 @@ func (s *integrationSuite) TestUpdateWorkflow_FailWorkflowTask() { // Complete workflow. 
_, err = poller.PollAndProcessWorkflowTask(false, false) - s.Error(err, "update was never successfully accepted so it prevents completion") + s.NoError(err) s.Equal(5, wtHandlerCalls) s.Equal(5, msgHandlerCalls) @@ -1795,7 +1795,11 @@ func (s *integrationSuite) TestUpdateWorkflow_FailWorkflowTask() { 5 ActivityTaskScheduled 6 WorkflowTaskScheduled 7 WorkflowTaskStarted - 8 WorkflowTaskFailed`, events) + 8 WorkflowTaskFailed + 9 WorkflowTaskScheduled + 10 WorkflowTaskStarted + 11 WorkflowTaskCompleted + 12 WorkflowExecutionCompleted`, events) } func (s *integrationSuite) TestUpdateWorkflow_ConvertStartedSpeculativeWorkflowTaskToNormal_BecauseOfBufferedSignal() { @@ -2559,3 +2563,134 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduledSpeculativeWorkflowTask_T // completion_event_batch_id should point to WFTerminated event. s.EqualValues(6, msResp.GetDatabaseMutableState().GetExecutionInfo().GetCompletionEventBatchId()) } + +func (s *integrationSuite) TestUpdateWorkflow_CompleteWorkflow_CancelUpdate() { + testCases := []struct { + Name string + UpdateErrMsg string + Commands func(tv *testvars.TestVars) []*commandpb.Command + Messages func(tv *testvars.TestVars, updRequestMsg *protocolpb.Message) []*protocolpb.Message + }{ + { + Name: "requested", + UpdateErrMsg: "update canceled", + Commands: func(_ *testvars.TestVars) []*commandpb.Command { return nil }, + Messages: func(_ *testvars.TestVars, _ *protocolpb.Message) []*protocolpb.Message { return nil }, + }, + { + Name: "accepted", + UpdateErrMsg: "update canceled", + Commands: func(tv *testvars.TestVars) []*commandpb.Command { return s.acceptUpdateCommands(tv, "1") }, + Messages: func(tv *testvars.TestVars, updRequestMsg *protocolpb.Message) []*protocolpb.Message { + return s.acceptUpdateMessages(tv, updRequestMsg, "1") + }, + }, + { + Name: "completed", + UpdateErrMsg: "", + Commands: func(tv *testvars.TestVars) []*commandpb.Command { return s.acceptCompleteUpdateCommands(tv, "1") }, + Messages: func(tv 
*testvars.TestVars, updRequestMsg *protocolpb.Message) []*protocolpb.Message { + return s.acceptCompleteUpdateMessages(tv, updRequestMsg, "1") + }, + }, + } + + for _, tc := range testCases { + s.T().Run(tc.Name, func(t *testing.T) { + tv := testvars.New(t.Name()) + + tv = s.startWorkflow(tv) + + wtHandlerCalls := 0 + wtHandler := func(execution *commonpb.WorkflowExecution, wt *commonpb.WorkflowType, previousStartedEventID, startedEventID int64, history *historypb.History) ([]*commandpb.Command, error) { + wtHandlerCalls++ + switch wtHandlerCalls { + case 1: + // Completes first WT with update unrelated command. + return []*commandpb.Command{{ + CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, + Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{ + ActivityId: tv.ActivityID(), + ActivityType: tv.ActivityType(), + TaskQueue: tv.TaskQueue(), + ScheduleToCloseTimeout: tv.InfiniteTimeout(), + }}, + }}, nil + case 2: + return append(tc.Commands(tv), &commandpb.Command{ + CommandType: enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION, + Attributes: &commandpb.Command_CompleteWorkflowExecutionCommandAttributes{CompleteWorkflowExecutionCommandAttributes: &commandpb.CompleteWorkflowExecutionCommandAttributes{}}, + }), nil + default: + s.Failf("wtHandler called too many times", "wtHandler shouldn't be called %d times", wtHandlerCalls) + return nil, nil + } + } + + msgHandlerCalls := 0 + msgHandler := func(task *workflowservice.PollWorkflowTaskQueueResponse) ([]*protocolpb.Message, error) { + msgHandlerCalls++ + switch msgHandlerCalls { + case 1: + return nil, nil + case 2: + updRequestMsg := task.Messages[0] + return tc.Messages(tv, updRequestMsg), nil + default: + s.Failf("msgHandler called too many times", "msgHandler shouldn't be called %d times", msgHandlerCalls) + return nil, nil + } + } + + poller := &TaskPoller{ + Engine: s.engine, + Namespace: 
s.namespace, + TaskQueue: tv.TaskQueue(), + WorkflowTaskHandler: wtHandler, + MessageHandler: msgHandler, + Logger: s.Logger, + T: s.T(), + } + + // Drain the existing first workflow task. + _, err := poller.PollAndProcessWorkflowTask(true, false) + s.NoError(err) + + updateResultCh := make(chan struct{}) + go func(updateErrMsg string) { + halfSecondTimeoutCtx, cancel := context.WithTimeout(NewContext(), 500*time.Millisecond) + defer cancel() + + resp, err1 := s.engine.UpdateWorkflowExecution(halfSecondTimeoutCtx, &workflowservice.UpdateWorkflowExecutionRequest{ + Namespace: s.namespace, + WorkflowExecution: tv.WorkflowExecution(), + Request: &updatepb.Request{ + Meta: &updatepb.Meta{UpdateId: tv.UpdateID("1")}, + Input: &updatepb.Input{ + Name: tv.HandlerName(), + Args: payloads.EncodeString(tv.String("args", "1")), + }, + }, + }) + + if updateErrMsg == "" { + s.NoError(err1) + s.NotNil(resp) + } else { + s.Error(err1) + s.True(common.IsContextDeadlineExceededErr(err1)) + s.Nil(resp) + } + updateResultCh <- struct{}{} + }(tc.UpdateErrMsg) + + // Complete workflow. 
+ _, err = poller.PollAndProcessWorkflowTask(false, false) + s.NoError(err) + <-updateResultCh + + s.Equal(2, wtHandlerCalls) + s.Equal(2, msgHandlerCalls) + }) + } +} From d0c187dd5ade3e0a3442fa8e8ba22fcdf6179d2a Mon Sep 17 00:00:00 2001 From: Alex Shtin Date: Fri, 26 May 2023 13:01:12 -0700 Subject: [PATCH 2/2] Fix unit tests --- service/history/workflow/update/registry_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/service/history/workflow/update/registry_test.go b/service/history/workflow/update/registry_test.go index c7eb3840c08..07c7bcc668f 100644 --- a/service/history/workflow/update/registry_test.go +++ b/service/history/workflow/update/registry_test.go @@ -35,6 +35,7 @@ import ( failurepb "go.temporal.io/api/failure/v1" "go.temporal.io/api/serviceerror" updatepb "go.temporal.io/api/update/v1" + historyspb "go.temporal.io/server/api/history/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/internal/effect" @@ -104,7 +105,6 @@ func TestFind(t *testing.T) { _, found, err := reg.FindOrCreate(ctx, updateID) require.NoError(t, err) require.False(t, found) - require.True(t, reg.HasUndeliveredUpdates()) _, ok = reg.Find(ctx, updateID) require.True(t, ok) @@ -195,8 +195,6 @@ func TestFindOrCreate(t *testing.T) { reg = update.NewRegistry(store) ) - require.False(t, reg.HasUndeliveredUpdates()) - t.Run("new update", func(t *testing.T) { updateID := "a completely new update ID" _, found, err := reg.FindOrCreate(ctx, updateID) @@ -253,7 +251,6 @@ func TestUpdateRemovalFromRegistry(t *testing.T) { evStore := mockEventStore{Controller: &effects} meta := updatepb.Meta{UpdateId: storedAcceptedUpdateID} outcome := successOutcome(t, "success!") - require.False(t, reg.HasUndeliveredUpdates(), "accepted is not undelivered") err = upd.OnMessage( ctx, @@ -262,7 +259,6 @@ func TestUpdateRemovalFromRegistry(t *testing.T) { ) require.NoError(t, err) - require.False(t, reg.HasUndeliveredUpdates(), "updates should be 
ProvisionallyCompleted") require.Equal(t, 1, reg.Len(), "update should still be present in map") effects.Apply(ctx) require.Equal(t, 0, reg.Len(), "update should have been removed")