Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow workflow to complete if there are updates in the registry #4412

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions service/history/workflow/update/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
protocolpb "go.temporal.io/api/protocol/v1"
"go.temporal.io/api/serviceerror"
updatepb "go.temporal.io/api/update/v1"

persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/common/future"
"go.temporal.io/server/common/log"
Expand All @@ -55,14 +56,13 @@ type (
// new update if no update is found.
Find(ctx context.Context, protocolInstanceID string) (*Update, bool)

// ReadOutoundMessages polls each registered Update for outbound
// ReadOutgoingMessages polls each registered Update for outbound
// messages and returns them.
ReadOutgoingMessages(startedEventID int64) ([]*protocolpb.Message, error)

// HasUndeliveredUpdates returns true if the registry has Updates that
// are not known to have been seen by user workflow code. In practice
// this means updates that have not yet been accepted or rejected.
HasUndeliveredUpdates() bool
// TerminateUpdates terminates all existing updates in the registry
// and notifies update API callers with the corresponding error.
TerminateUpdates(ctx context.Context, eventStore EventStore)

// HasOutgoing returns true if the registry has any Updates that want to
// send messages to a worker.
Expand Down Expand Up @@ -167,15 +167,10 @@ func (r *RegistryImpl) Find(ctx context.Context, id string) (*Update, bool) {
return r.findLocked(ctx, id)
}

func (r *RegistryImpl) HasUndeliveredUpdates() bool {
r.mu.RLock()
defer r.mu.RUnlock()
for _, upd := range r.updates {
if !upd.hasBeenSeenByWorkflowExecution() {
return true
}
}
return false
// TerminateUpdates is currently a placeholder with an empty body: both the
// context and the EventStore arguments are ignored and the registry is left
// untouched.
func (r *RegistryImpl) TerminateUpdates(_ context.Context, _ EventStore) {
// TODO (alex-update): implement.
// This method is not implemented yet, so update API callers will simply time out.
// In the future, it should remove all existing updates and notify callers with a better error.
}

func (r *RegistryImpl) HasOutgoing() bool {
Expand Down
6 changes: 1 addition & 5 deletions service/history/workflow/update/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
failurepb "go.temporal.io/api/failure/v1"
"go.temporal.io/api/serviceerror"
updatepb "go.temporal.io/api/update/v1"

historyspb "go.temporal.io/server/api/history/v1"
persistencespb "go.temporal.io/server/api/persistence/v1"
"go.temporal.io/server/internal/effect"
Expand Down Expand Up @@ -104,7 +105,6 @@ func TestFind(t *testing.T) {
_, found, err := reg.FindOrCreate(ctx, updateID)
require.NoError(t, err)
require.False(t, found)
require.True(t, reg.HasUndeliveredUpdates())

_, ok = reg.Find(ctx, updateID)
require.True(t, ok)
Expand Down Expand Up @@ -195,8 +195,6 @@ func TestFindOrCreate(t *testing.T) {
reg = update.NewRegistry(store)
)

require.False(t, reg.HasUndeliveredUpdates())

t.Run("new update", func(t *testing.T) {
updateID := "a completely new update ID"
_, found, err := reg.FindOrCreate(ctx, updateID)
Expand Down Expand Up @@ -253,7 +251,6 @@ func TestUpdateRemovalFromRegistry(t *testing.T) {
evStore := mockEventStore{Controller: &effects}
meta := updatepb.Meta{UpdateId: storedAcceptedUpdateID}
outcome := successOutcome(t, "success!")
require.False(t, reg.HasUndeliveredUpdates(), "accepted is not undelivered")

err = upd.OnMessage(
ctx,
Expand All @@ -262,7 +259,6 @@ func TestUpdateRemovalFromRegistry(t *testing.T) {
)

require.NoError(t, err)
require.False(t, reg.HasUndeliveredUpdates(), "updates should be ProvisionallyCompleted")
require.Equal(t, 1, reg.Len(), "update should still be present in map")
effects.Apply(ctx)
require.Equal(t, 0, reg.Len(), "update should have been removed")
Expand Down
19 changes: 6 additions & 13 deletions service/history/workflowTaskHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
protocolpb "go.temporal.io/api/protocol/v1"
"go.temporal.io/api/serviceerror"
"go.temporal.io/api/workflowservice/v1"

"go.temporal.io/server/internal/effect"
"go.temporal.io/server/internal/protocol"
"go.temporal.io/server/service/history/workflow/update"
Expand Down Expand Up @@ -568,9 +569,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCompleteWorkflow(
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil)
}

if handler.updateRegistry.HasUndeliveredUpdates() {
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil)
}
handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState))

if err := handler.validateCommandAttr(
func() (enumspb.WorkflowTaskFailedCause, error) {
Expand Down Expand Up @@ -630,9 +629,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandFailWorkflow(
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil)
}

if handler.updateRegistry.HasUndeliveredUpdates() {
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil)
}
handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState))

if err := handler.validateCommandAttr(
func() (enumspb.WorkflowTaskFailedCause, error) {
Expand Down Expand Up @@ -726,7 +723,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCancelTimer(
}

func (handler *workflowTaskHandlerImpl) handleCommandCancelWorkflow(
_ context.Context,
ctx context.Context,
attr *commandpb.CancelWorkflowExecutionCommandAttributes,
) error {

Expand All @@ -736,9 +733,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandCancelWorkflow(
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil)
}

if handler.updateRegistry.HasUndeliveredUpdates() {
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil)
}
handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState))

if err := handler.validateCommandAttr(
func() (enumspb.WorkflowTaskFailedCause, error) {
Expand Down Expand Up @@ -843,9 +838,7 @@ func (handler *workflowTaskHandlerImpl) handleCommandContinueAsNewWorkflow(
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_COMMAND, nil)
}

if handler.updateRegistry.HasUndeliveredUpdates() {
return handler.failWorkflowTask(enumspb.WORKFLOW_TASK_FAILED_CAUSE_UNHANDLED_UPDATE, nil)
}
handler.updateRegistry.TerminateUpdates(ctx, workflow.WithEffects(handler.effects, handler.mutableState))

namespaceName := handler.mutableState.GetNamespaceEntry().Name()

Expand Down
139 changes: 137 additions & 2 deletions tests/update_workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1781,7 +1781,7 @@ func (s *integrationSuite) TestUpdateWorkflow_FailWorkflowTask() {

// Complete workflow.
_, err = poller.PollAndProcessWorkflowTask(false, false)
s.Error(err, "update was never successfully accepted so it prevents completion")
s.NoError(err)

s.Equal(5, wtHandlerCalls)
s.Equal(5, msgHandlerCalls)
Expand All @@ -1795,7 +1795,11 @@ func (s *integrationSuite) TestUpdateWorkflow_FailWorkflowTask() {
5 ActivityTaskScheduled
6 WorkflowTaskScheduled
7 WorkflowTaskStarted
8 WorkflowTaskFailed`, events)
8 WorkflowTaskFailed
9 WorkflowTaskScheduled
10 WorkflowTaskStarted
11 WorkflowTaskCompleted
12 WorkflowExecutionCompleted`, events)
}

func (s *integrationSuite) TestUpdateWorkflow_ConvertStartedSpeculativeWorkflowTaskToNormal_BecauseOfBufferedSignal() {
Expand Down Expand Up @@ -2559,3 +2563,134 @@ func (s *integrationSuite) TestUpdateWorkflow_ScheduledSpeculativeWorkflowTask_T
// completion_event_batch_id should point to WFTerminated event.
s.EqualValues(6, msResp.GetDatabaseMutableState().GetExecutionInfo().GetCompletionEventBatchId())
}

// TestUpdateWorkflow_CompleteWorkflow_CancelUpdate checks what an update API
// caller observes when the second workflow task returns a
// CompleteWorkflowExecution command while an update call is outstanding.
//
// Each test case ("requested", "accepted", "completed") supplies the update
// related commands and messages that the worker returns together with the
// completion command. For "requested" and "accepted" the update call is
// expected to fail with a context deadline exceeded error and a nil response;
// for "completed" it is expected to succeed with a non-nil response.
func (s *integrationSuite) TestUpdateWorkflow_CompleteWorkflow_CancelUpdate() {
testCases := []struct {
Name string
// UpdateErrMsg is only checked for emptiness below: a non-empty value
// means the update call is expected to fail. The text itself is not
// compared against the returned error.
UpdateErrMsg string
// Commands returns the update related commands that are prepended to the
// CompleteWorkflowExecution command in the second workflow task.
Commands func(tv *testvars.TestVars) []*commandpb.Command
// Messages returns the messages sent back in response to the update
// request message delivered with the second workflow task.
Messages func(tv *testvars.TestVars, updRequestMsg *protocolpb.Message) []*protocolpb.Message
}{
{
Name: "requested",
UpdateErrMsg: "update canceled",
Commands: func(_ *testvars.TestVars) []*commandpb.Command { return nil },
Messages: func(_ *testvars.TestVars, _ *protocolpb.Message) []*protocolpb.Message { return nil },
},
{
Name: "accepted",
UpdateErrMsg: "update canceled",
Commands: func(tv *testvars.TestVars) []*commandpb.Command { return s.acceptUpdateCommands(tv, "1") },
Messages: func(tv *testvars.TestVars, updRequestMsg *protocolpb.Message) []*protocolpb.Message {
return s.acceptUpdateMessages(tv, updRequestMsg, "1")
},
},
{
Name: "completed",
UpdateErrMsg: "",
Commands: func(tv *testvars.TestVars) []*commandpb.Command { return s.acceptCompleteUpdateCommands(tv, "1") },
Messages: func(tv *testvars.TestVars, updRequestMsg *protocolpb.Message) []*protocolpb.Message {
return s.acceptCompleteUpdateMessages(tv, updRequestMsg, "1")
},
},
}

for _, tc := range testCases {
s.T().Run(tc.Name, func(t *testing.T) {
tv := testvars.New(t.Name())

tv = s.startWorkflow(tv)

// The workflow task handler must be invoked exactly twice; a third
// invocation fails the test.
wtHandlerCalls := 0
wtHandler := func(execution *commonpb.WorkflowExecution, wt *commonpb.WorkflowType, previousStartedEventID, startedEventID int64, history *historypb.History) ([]*commandpb.Command, error) {
wtHandlerCalls++
switch wtHandlerCalls {
case 1:
// Complete the first WT with a command unrelated to the update.
return []*commandpb.Command{{
CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK,
Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{
ActivityId: tv.ActivityID(),
ActivityType: tv.ActivityType(),
TaskQueue: tv.TaskQueue(),
ScheduleToCloseTimeout: tv.InfiniteTimeout(),
}},
}}, nil
case 2:
// Return the test case's update commands followed by the workflow
// completion command in the same workflow task.
return append(tc.Commands(tv), &commandpb.Command{
CommandType: enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION,
Attributes: &commandpb.Command_CompleteWorkflowExecutionCommandAttributes{CompleteWorkflowExecutionCommandAttributes: &commandpb.CompleteWorkflowExecutionCommandAttributes{}},
}), nil
default:
s.Failf("wtHandler called too many times", "wtHandler shouldn't be called %d times", wtHandlerCalls)
return nil, nil
}
}

// The message handler mirrors the workflow task handler: two calls are
// expected and a third fails the test.
msgHandlerCalls := 0
msgHandler := func(task *workflowservice.PollWorkflowTaskQueueResponse) ([]*protocolpb.Message, error) {
msgHandlerCalls++
switch msgHandlerCalls {
case 1:
return nil, nil
case 2:
// The first message of the second task is treated as the update request.
updRequestMsg := task.Messages[0]
return tc.Messages(tv, updRequestMsg), nil
default:
s.Failf("msgHandler called too many times", "msgHandler shouldn't be called %d times", msgHandlerCalls)
return nil, nil
}
}

poller := &TaskPoller{
Engine: s.engine,
Namespace: s.namespace,
TaskQueue: tv.TaskQueue(),
WorkflowTaskHandler: wtHandler,
MessageHandler: msgHandler,
Logger: s.Logger,
T: s.T(),
}

// Drain existing first workflow task.
_, err := poller.PollAndProcessWorkflowTask(true, false)
s.NoError(err)

// Issue the update from a separate goroutine so that the poller below
// runs while the update call is outstanding. The goroutine signals
// updateResultCh once its assertions are done.
updateResultCh := make(chan struct{})
go func(updateErrMsg string) {
// The update call is bounded by a 500ms timeout.
halfSecondTimeoutCtx, cancel := context.WithTimeout(NewContext(), 500*time.Millisecond)
defer cancel()

resp, err1 := s.engine.UpdateWorkflowExecution(halfSecondTimeoutCtx, &workflowservice.UpdateWorkflowExecutionRequest{
Namespace: s.namespace,
WorkflowExecution: tv.WorkflowExecution(),
Request: &updatepb.Request{
Meta: &updatepb.Meta{UpdateId: tv.UpdateID("1")},
Input: &updatepb.Input{
Name: tv.HandlerName(),
Args: payloads.EncodeString(tv.String("args", "1")),
},
},
})

if updateErrMsg == "" {
s.NoError(err1)
s.NotNil(resp)
} else {
// The failure is asserted to be a context deadline exceeded error,
// i.e. the caller times out instead of receiving a specific error.
s.Error(err1)
s.True(common.IsContextDeadlineExceededErr(err1))
s.Nil(resp)
}
updateResultCh <- struct{}{}
}(tc.UpdateErrMsg)

// Complete workflow.
_, err = poller.PollAndProcessWorkflowTask(false, false)
s.NoError(err)
// Wait for the update goroutine to finish before checking call counts.
<-updateResultCh

s.Equal(2, wtHandlerCalls)
s.Equal(2, msgHandlerCalls)
})
}
}