Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functional test for shard reload during update #4437

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion service/history/workflow/update/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
protocolpb "go.temporal.io/api/protocol/v1"
"go.temporal.io/api/serviceerror"
updatepb "go.temporal.io/api/update/v1"

updatespb "go.temporal.io/server/api/update/v1"
"go.temporal.io/server/common/future"
"go.temporal.io/server/common/log"
Expand Down Expand Up @@ -208,7 +209,7 @@ func (r *RegistryImpl) ReadOutgoingMessages(
}

// TODO (alex-update): currently sequencing_id is simply pointing to the
// event before WorkflowTaskStartedEvent. SDKs are supposed to respect this
// event before WorkflowTaskStartedEvent. SDKs are supposed to respect this
// and process messages (specifically, updates) after event with that ID.
// In the future, sequencing_id could point to some specific event
// (specifically, signal) after which the update should be processed.
Expand Down
16 changes: 16 additions & 0 deletions tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import (
commonpb "go.temporal.io/api/common/v1"
"go.temporal.io/api/workflowservice/v1"

"go.temporal.io/server/api/adminservice/v1"
"go.temporal.io/server/common"
"go.temporal.io/server/common/dynamicconfig"
"go.temporal.io/server/common/payloads"
)
Expand Down Expand Up @@ -85,6 +87,20 @@ func (s *integrationSuite) sendSignal(namespace string, execution *commonpb.Work
return err
}

func (s *integrationSuite) closeShard(wid string) {
s.T().Helper()

resp, err := s.engine.DescribeNamespace(NewContext(), &workflowservice.DescribeNamespaceRequest{
Namespace: s.namespace,
})
s.NoError(err)

_, err = s.adminClient.CloseShard(NewContext(), &adminservice.CloseShardRequest{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So does this actually do what the function name implies?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, at least it does what I want it to do: both new tests assert that the update/WT is gone after the shard is closed.

ShardId: common.WorkflowIDToHistoryShard(resp.NamespaceInfo.Id, wid, s.testClusterConfig.HistoryConfig.NumHistoryShards),
})
s.NoError(err)
}

func unmarshalAny[T proto.Message](s *integrationSuite, a *types.Any) T {
s.T().Helper()
pb := new(T)
Expand Down
8 changes: 8 additions & 0 deletions tests/testvars/test_vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,14 @@ func (tv *TestVars) WithHandlerName(handlerName string, key ...string) *TestVars
return tv.cloneSet("handler_name", handlerName, key)
}

// WorkerIdentity returns the result of tv.getOrCreate for the
// "worker_identity" entry. The optional key is passed through unchanged.
func (tv *TestVars) WorkerIdentity(key ...string) string {
	identity := tv.getOrCreate("worker_identity", key)
	return identity
}

// WithWorkerIdentity returns the result of tv.cloneSet with the
// "worker_identity" entry set to identity. The optional key is passed
// through unchanged.
func (tv *TestVars) WithWorkerIdentity(identity string, key ...string) *TestVars {
	withIdentity := tv.cloneSet("worker_identity", identity, key)
	return withIdentity
}

// ----------- Generic methods ------------

func (tv *TestVars) InfiniteTimeout() *time.Duration {
Expand Down
284 changes: 284 additions & 0 deletions tests/update_workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2985,3 +2985,287 @@ func (s *integrationSuite) TestUpdateWorkflow_SpeculativeWorkflowTask_Heartbeat(
14 WorkflowTaskCompleted
15 WorkflowExecutionCompleted`, events)
}

// TestUpdateWorkflow_NewScheduledSpeculativeWorkflowTaskLost_BecauseOfShardMove
// sends an update, closes the workflow's shard before any worker picks up the
// resulting workflow task, and asserts that the update is gone afterwards: the
// update call times out, a direct poll returns no messages, and the final
// history contains no update-related events.
func (s *integrationSuite) TestUpdateWorkflow_NewScheduledSpeculativeWorkflowTaskLost_BecauseOfShardMove() {
tv := testvars.New(s.T().Name())

tv = s.startWorkflow(tv)

// wtHandler is called once per workflow task processed through poller; the
// call counter selects the expected behavior for each task.
wtHandlerCalls := 0
wtHandler := func(execution *commonpb.WorkflowExecution, wt *commonpb.WorkflowType, previousStartedEventID, startedEventID int64, history *historypb.History) ([]*commandpb.Command, error) {
wtHandlerCalls++
switch wtHandlerCalls {
case 1:
// Completes first WT with update unrelated command.
return []*commandpb.Command{{
CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK,
Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{
ActivityId: tv.ActivityID(),
ActivityType: tv.ActivityType(),
TaskQueue: tv.TaskQueue(),
ScheduleToCloseTimeout: tv.InfiniteTimeout(),
}},
}}, nil
case 2:
// Second WT is the one created by the signal sent after the shard was
// closed. The expected history has the signal at event 6 and no update
// events.
s.EqualHistory(`
1 WorkflowExecutionStarted
2 WorkflowTaskScheduled
3 WorkflowTaskStarted
4 WorkflowTaskCompleted
5 ActivityTaskScheduled
6 WorkflowExecutionSignaled
7 WorkflowTaskScheduled
8 WorkflowTaskStarted`, history)
return []*commandpb.Command{{
CommandType: enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION,
Attributes: &commandpb.Command_CompleteWorkflowExecutionCommandAttributes{CompleteWorkflowExecutionCommandAttributes: &commandpb.CompleteWorkflowExecutionCommandAttributes{}},
}}, nil
default:
s.Failf("wtHandler called too many times", "wtHandler shouldn't be called %d times", wtHandlerCalls)
return nil, nil
}
}

// msgHandler mirrors wtHandler: one call per workflow task processed
// through poller, with the counter selecting the expectation.
msgHandlerCalls := 0
msgHandler := func(task *workflowservice.PollWorkflowTaskQueueResponse) ([]*protocolpb.Message, error) {
msgHandlerCalls++
switch msgHandlerCalls {
case 1:
return nil, nil
case 2:
s.Empty(task.Messages, "update must be lost due to shard reload")
return nil, nil
default:
s.Failf("msgHandler called too many times", "msgHandler shouldn't be called %d times", msgHandlerCalls)
return nil, nil
}
}

poller := &TaskPoller{
Engine: s.engine,
Namespace: s.namespace,
TaskQueue: tv.TaskQueue(),
WorkflowTaskHandler: wtHandler,
MessageHandler: msgHandler,
Logger: s.Logger,
T: s.T(),
}

// Drain existing first workflow task.
_, err := poller.PollAndProcessWorkflowTask(true, false)
s.NoError(err)

// updateResultCh is signaled once the update goroutine has finished its
// assertions. It is unbuffered, so the goroutine blocks on the send until
// the test body receives from it below.
updateResultCh := make(chan struct{})
updateWorkflowFn := func() {
halfSecondTimeoutCtx, cancel := context.WithTimeout(NewContext(), 500*time.Millisecond)
defer cancel()

updateResponse, err1 := s.engine.UpdateWorkflowExecution(halfSecondTimeoutCtx, &workflowservice.UpdateWorkflowExecutionRequest{
Namespace: s.namespace,
WorkflowExecution: tv.WorkflowExecution(),
Request: &updatepb.Request{
Meta: &updatepb.Meta{UpdateId: tv.UpdateID("1")},
Input: &updatepb.Input{
Name: tv.Any(),
Args: payloads.EncodeString(tv.Any()),
},
},
})
// The update is never processed by a worker, so the call is expected to
// end with a deadline-exceeded error and no response.
assert.Error(s.T(), err1)
assert.True(s.T(), common.IsContextDeadlineExceededErr(err1), err1)
assert.Nil(s.T(), updateResponse)

updateResultCh <- struct{}{}
}
go updateWorkflowFn()

// Close shard, Speculative WT with update will be lost.
s.closeShard(tv.WorkflowID())

// Poll the task queue directly (not through poller, so neither handler is
// invoked) and check that the response carries no messages.
pollCtx, cancel := context.WithTimeout(NewContext(), common.MinLongPollTimeout+100*time.Millisecond)
defer cancel()
pollResponse, err := s.engine.PollWorkflowTaskQueue(pollCtx, &workflowservice.PollWorkflowTaskQueueRequest{
Namespace: s.namespace,
TaskQueue: tv.TaskQueue(),
Identity: tv.WorkerIdentity(),
})
s.NoError(err)
// NOTE(review): if s.NoError does not abort the test on failure, a nil
// pollResponse would panic on the field access below — confirm.
s.Nil(pollResponse.Messages)

// Wait for the update goroutine to finish its assertions.
<-updateResultCh

// Send signal to schedule new WT.
err = s.sendSignal(s.namespace, tv.WorkflowExecution(), tv.Any(), payloads.EncodeString(tv.Any()), tv.Any())
s.NoError(err)

// Complete workflow.
completeWorkflowResp, err := poller.PollAndProcessWorkflowTask(false, false)
s.NoError(err)
s.NotNil(completeWorkflowResp)

// Only the two poller-driven workflow tasks reached the handlers.
s.Equal(2, wtHandlerCalls)
s.Equal(2, msgHandlerCalls)

events := s.getHistory(s.namespace, tv.WorkflowExecution())

// The final history contains no update events.
s.EqualHistoryEvents(`
1 WorkflowExecutionStarted
2 WorkflowTaskScheduled
3 WorkflowTaskStarted
4 WorkflowTaskCompleted
5 ActivityTaskScheduled
6 WorkflowExecutionSignaled
7 WorkflowTaskScheduled
8 WorkflowTaskStarted
9 WorkflowTaskCompleted
10 WorkflowExecutionCompleted`, events)
}

// TestUpdateWorkflow_NewStartedSpeculativeWorkflowTaskLost_BecauseOfShardMove
// sends an update, lets a worker start the resulting speculative workflow
// task, and closes the workflow's shard from inside the workflow task handler
// before the task is completed. It asserts that completing the task fails
// with NotFound, the update call times out, and the final history contains
// neither the speculative workflow task nor any update events.
func (s *integrationSuite) TestUpdateWorkflow_NewStartedSpeculativeWorkflowTaskLost_BecauseOfShardMove() {
tv := testvars.New(s.T().Name())

tv = s.startWorkflow(tv)

// wtHandler is called once per workflow task processed through poller; the
// call counter selects the expected behavior for each task.
wtHandlerCalls := 0
wtHandler := func(execution *commonpb.WorkflowExecution, wt *commonpb.WorkflowType, previousStartedEventID, startedEventID int64, history *historypb.History) ([]*commandpb.Command, error) {
wtHandlerCalls++
switch wtHandlerCalls {
case 1:
// Completes first WT with update unrelated command.
return []*commandpb.Command{{
CommandType: enumspb.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK,
Attributes: &commandpb.Command_ScheduleActivityTaskCommandAttributes{ScheduleActivityTaskCommandAttributes: &commandpb.ScheduleActivityTaskCommandAttributes{
ActivityId: tv.ActivityID(),
ActivityType: tv.ActivityType(),
TaskQueue: tv.TaskQueue(),
ScheduleToCloseTimeout: tv.InfiniteTimeout(),
}},
}}, nil
case 2:
// Speculative WT, with update.Request message.
s.EqualHistory(`
1 WorkflowExecutionStarted
2 WorkflowTaskScheduled
3 WorkflowTaskStarted
4 WorkflowTaskCompleted
5 ActivityTaskScheduled
6 WorkflowTaskScheduled
7 WorkflowTaskStarted`, history)

// Close shard, Speculative WT (6 and 7) will be lost, and NotFound error will be returned to RespondWorkflowTaskCompleted.
s.closeShard(tv.WorkflowID())

// The commands are still returned so that the poller attempts to complete
// the (now lost) workflow task.
return s.acceptCompleteUpdateCommands(tv, "1"), nil
case 3:
// Third WT is the one created by the signal sent after the shard was
// closed. Events 6-8 are now the signal and its WT, not the lost
// speculative WT.
s.EqualHistory(`
1 WorkflowExecutionStarted
2 WorkflowTaskScheduled
3 WorkflowTaskStarted
4 WorkflowTaskCompleted
5 ActivityTaskScheduled
6 WorkflowExecutionSignaled
7 WorkflowTaskScheduled
8 WorkflowTaskStarted`, history)
return []*commandpb.Command{{
CommandType: enumspb.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION,
Attributes: &commandpb.Command_CompleteWorkflowExecutionCommandAttributes{CompleteWorkflowExecutionCommandAttributes: &commandpb.CompleteWorkflowExecutionCommandAttributes{}},
}}, nil
default:
s.Failf("wtHandler called too many times", "wtHandler shouldn't be called %d times", wtHandlerCalls)
return nil, nil
}
}

// msgHandler mirrors wtHandler: one call per workflow task processed
// through poller, with the counter selecting the expectation.
msgHandlerCalls := 0
msgHandler := func(task *workflowservice.PollWorkflowTaskQueueResponse) ([]*protocolpb.Message, error) {
msgHandlerCalls++
switch msgHandlerCalls {
case 1:
return nil, nil
case 2:
// The speculative WT carries the update request. Its event ID is asserted
// to be 6, which is the WorkflowTaskScheduled event of the speculative WT
// in the history checked by wtHandler case 2.
// NOTE(review): task.Messages[0] panics if no message was delivered;
// consider asserting the slice is non-empty first.
updRequestMsg := task.Messages[0]
s.EqualValues(6, updRequestMsg.GetEventId())

return s.acceptCompleteUpdateMessages(tv, updRequestMsg, "1"), nil
case 3:
s.Empty(task.Messages, "update must be lost due to shard reload")
return nil, nil
default:
s.Failf("msgHandler called too many times", "msgHandler shouldn't be called %d times", msgHandlerCalls)
return nil, nil
}
}

poller := &TaskPoller{
Engine: s.engine,
Namespace: s.namespace,
TaskQueue: tv.TaskQueue(),
WorkflowTaskHandler: wtHandler,
MessageHandler: msgHandler,
Logger: s.Logger,
T: s.T(),
}

// Drain existing first workflow task.
_, err := poller.PollAndProcessWorkflowTask(true, false)
s.NoError(err)

// updateResultCh is signaled once the update goroutine has finished its
// assertions. It is unbuffered, so the goroutine blocks on the send until
// the test body receives from it below.
updateResultCh := make(chan struct{})
updateWorkflowFn := func() {
halfSecondTimeoutCtx, cancel := context.WithTimeout(NewContext(), 500*time.Millisecond)
defer cancel()

updateResponse, err1 := s.engine.UpdateWorkflowExecution(halfSecondTimeoutCtx, &workflowservice.UpdateWorkflowExecutionRequest{
Namespace: s.namespace,
WorkflowExecution: tv.WorkflowExecution(),
Request: &updatepb.Request{
Meta: &updatepb.Meta{UpdateId: tv.UpdateID("1")},
Input: &updatepb.Input{
Name: tv.Any(),
Args: payloads.EncodeString(tv.Any()),
},
},
})
// The worker's completion of the speculative WT is rejected, so the update
// call is expected to end with a deadline-exceeded error and no response.
assert.Error(s.T(), err1)
assert.True(s.T(), common.IsContextDeadlineExceededErr(err1), err1)
assert.Nil(s.T(), updateResponse)

updateResultCh <- struct{}{}
}
go updateWorkflowFn()

// Process update in workflow.
// The shard is closed inside wtHandler (case 2), so completing this workflow
// task is expected to fail with NotFound and yield no update response.
_, updateResp, err := poller.PollAndProcessWorkflowTaskWithAttemptAndRetryAndForceNewWorkflowTask(false, false, false, false, 1, 1, true, nil)
s.Error(err)
s.IsType(&serviceerror.NotFound{}, err)
s.ErrorContains(err, "Workflow task not found")
s.Nil(updateResp)

// Wait for the update goroutine to finish its assertions.
<-updateResultCh

// Send signal to schedule new WT.
err = s.sendSignal(s.namespace, tv.WorkflowExecution(), tv.Any(), payloads.EncodeString(tv.Any()), tv.Any())
s.NoError(err)

// Complete workflow.
completeWorkflowResp, err := poller.PollAndProcessWorkflowTask(false, false)
s.NoError(err)
s.NotNil(completeWorkflowResp)

// Three workflow tasks reached the handlers: the initial one, the lost
// speculative one, and the signal-triggered one.
s.Equal(3, wtHandlerCalls)
s.Equal(3, msgHandlerCalls)

events := s.getHistory(s.namespace, tv.WorkflowExecution())

// The final history contains no update events and no trace of the lost
// speculative WT.
s.EqualHistoryEvents(`
1 WorkflowExecutionStarted
2 WorkflowTaskScheduled
3 WorkflowTaskStarted
4 WorkflowTaskCompleted
5 ActivityTaskScheduled
6 WorkflowExecutionSignaled
7 WorkflowTaskScheduled
8 WorkflowTaskStarted
9 WorkflowTaskCompleted
10 WorkflowExecutionCompleted`, events)
}