From 2de41e587256926ecc6cb4f95ac9d5d8d0baea76 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Thu, 25 May 2023 23:17:09 -0700 Subject: [PATCH] Handle worker versioning for sticky queues (#4397) Note: This commit came from a feature branch and is not expected to build. --- service/frontend/errors.go | 1 + service/frontend/workflow_handler.go | 61 ++--- service/history/api/get_workflow_util.go | 5 +- .../transferQueueActiveTaskExecutor_test.go | 3 +- .../history/workflow/mutable_state_impl.go | 10 +- service/matching/db.go | 10 +- service/matching/matchingEngine.go | 140 +++++----- service/matching/matchingEngine_test.go | 28 +- service/matching/taskQueueManager.go | 72 ++++-- service/matching/taskQueueManager_test.go | 61 ++++- service/matching/taskReader.go | 2 +- service/matching/taskWriter.go | 2 +- service/matching/version_sets.go | 51 +++- tests/versioning_test.go | 244 ++++++++++-------- 14 files changed, 436 insertions(+), 254 deletions(-) diff --git a/service/frontend/errors.go b/service/frontend/errors.go index 8f6f02f78be..aeac87b145e 100644 --- a/service/frontend/errors.go +++ b/service/frontend/errors.go @@ -81,6 +81,7 @@ var ( errInvalidWorkflowStartDelaySeconds = serviceerror.NewInvalidArgument("An invalid WorkflowStartDelaySeconds is set on request.") errRaceConditionAddingSearchAttributes = serviceerror.NewUnavailable("Generated search attributes mapping unavailble.") errUseVersioningWithoutBuildId = serviceerror.NewInvalidArgument("WorkerVersionStamp must be present if UseVersioning is true.") + errUseVersioningWithoutNormalName = serviceerror.NewInvalidArgument("NormalName must be set on sticky queue if UseVersioning is true.") errBuildIdTooLong = serviceerror.NewInvalidArgument("Build ID exceeds configured limit.workerBuildIdSize, use a shorter build ID.") errUpdateMetaNotSet = serviceerror.NewInvalidArgument("Update meta is not set on request.") diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 0394b1dd215..072c6e714aa 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -840,12 +840,8 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w return nil, errIdentityTooLong } - if request.GetWorkerVersionCapabilities().GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(request.Namespace) { - return nil, errWorkerVersioningNotAllowed - } - - if len(request.GetWorkerVersionCapabilities().GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() { - return nil, errBuildIdTooLong + if err := wh.validateVersioningInfo(request.Namespace, request.WorkerVersionCapabilities, request.TaskQueue); err != nil { + return nil, err } if err := wh.validateTaskQueue(request.TaskQueue); err != nil { @@ -858,14 +854,6 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w } namespaceID := namespaceEntry.ID() - // Copy WorkerVersionCapabilities.BuildId to BinaryChecksum if BinaryChecksum is missing (small - // optimization to save space in the poll request). 
- if request.WorkerVersionCapabilities != nil { - if len(request.WorkerVersionCapabilities.BuildId) > 0 && len(request.BinaryChecksum) == 0 { - request.BinaryChecksum = request.WorkerVersionCapabilities.BuildId - } - } - wh.logger.Debug("Poll workflow task queue.", tag.WorkflowNamespace(namespaceEntry.Name().String()), tag.WorkflowNamespaceID(namespaceID.String())) if err := wh.checkBadBinary(namespaceEntry, request.GetBinaryChecksum()); err != nil { return nil, err @@ -942,12 +930,12 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted( return nil, errIdentityTooLong } - if request.GetWorkerVersionStamp().GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(request.Namespace) { - return nil, errWorkerVersioningNotAllowed - } - - if len(request.GetWorkerVersionStamp().GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() { - return nil, errBuildIdTooLong + if err := wh.validateVersioningInfo( + request.Namespace, + request.WorkerVersionStamp, + request.StickyAttributes.GetWorkerTaskQueue(), + ); err != nil { + return nil, err } taskToken, err := wh.tokenSerializer.Deserialize(request.TaskToken) @@ -956,10 +944,6 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted( } namespaceId := namespace.ID(taskToken.GetNamespaceId()) - if request.WorkerVersionStamp.GetUseVersioning() && len(request.WorkerVersionStamp.GetBuildId()) == 0 { - return nil, errUseVersioningWithoutBuildId - } - wh.overrides.DisableEagerActivityDispatchForBuggyClients(ctx, request) histResp, err := wh.historyClient.RespondWorkflowTaskCompleted(ctx, &historyservice.RespondWorkflowTaskCompletedRequest{ @@ -1102,12 +1086,8 @@ func (wh *WorkflowHandler) PollActivityTaskQueue(ctx context.Context, request *w return nil, errIdentityTooLong } - if request.GetWorkerVersionCapabilities().GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(request.Namespace) { - return nil, errWorkerVersioningNotAllowed - } - - if len(request.GetWorkerVersionCapabilities().GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() { - return nil, errBuildIdTooLong + if err := wh.validateVersioningInfo(request.Namespace, request.WorkerVersionCapabilities, request.TaskQueue); err != nil { + return nil, err } namespaceID, err := wh.namespaceRegistry.GetNamespaceID(namespace.Name(request.GetNamespace())) @@ -4292,6 +4272,27 @@ func (wh *WorkflowHandler) validateTaskQueue(t *taskqueuepb.TaskQueue) error { return nil } +type buildIdAndFlag interface { + GetBuildId() string + GetUseVersioning() bool +} + +func (wh *WorkflowHandler) validateVersioningInfo(namespace string, id buildIdAndFlag, tq *taskqueuepb.TaskQueue) error { + if id.GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(namespace) { + return errWorkerVersioningNotAllowed + } + if id.GetUseVersioning() && tq.GetKind() == enumspb.TASK_QUEUE_KIND_STICKY && len(tq.GetNormalName()) == 0 { + return errUseVersioningWithoutNormalName + } + if id.GetUseVersioning() && len(id.GetBuildId()) == 0 { + return errUseVersioningWithoutBuildId + } + if len(id.GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() { + return errBuildIdTooLong + } + return nil +} + //nolint:revive // cyclomatic complexity func (wh *WorkflowHandler) validateBuildIdCompatibilityUpdate( req *workflowservice.UpdateWorkerBuildIdCompatibilityRequest, diff --git a/service/history/api/get_workflow_util.go b/service/history/api/get_workflow_util.go index f546958d361..f7bdc30a15f 100644 --- a/service/history/api/get_workflow_util.go +++ b/service/history/api/get_workflow_util.go @@ -205,8 +205,9 @@ func 
MutableStateToGetResponse( Kind: enumspb.TASK_QUEUE_KIND_NORMAL, }, StickyTaskQueue: &taskqueuepb.TaskQueue{ - Name: executionInfo.StickyTaskQueue, - Kind: enumspb.TASK_QUEUE_KIND_STICKY, + Name: executionInfo.StickyTaskQueue, + Kind: enumspb.TASK_QUEUE_KIND_STICKY, + NormalName: executionInfo.TaskQueue, }, StickyTaskQueueScheduleToStartTimeout: executionInfo.StickyScheduleToStartTimeout, CurrentBranchToken: currentBranchToken, diff --git a/service/history/transferQueueActiveTaskExecutor_test.go b/service/history/transferQueueActiveTaskExecutor_test.go index 483f97c0c89..aa829d27e3f 100644 --- a/service/history/transferQueueActiveTaskExecutor_test.go +++ b/service/history/transferQueueActiveTaskExecutor_test.go @@ -2642,8 +2642,9 @@ func (s *transferQueueActiveTaskExecutorSuite) createAddWorkflowTaskRequest( } executionInfo := mutableState.GetExecutionInfo() timeout := executionInfo.WorkflowRunTimeout - if mutableState.GetExecutionInfo().TaskQueue != task.TaskQueue { + if executionInfo.TaskQueue != task.TaskQueue { taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY + taskQueue.NormalName = executionInfo.TaskQueue timeout = executionInfo.StickyScheduleToStartTimeout } diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index e57e2f42e73..f1aa9f198ba 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -677,8 +677,9 @@ func (ms *MutableStateImpl) GetNamespaceEntry() *namespace.Namespace { func (ms *MutableStateImpl) CurrentTaskQueue() *taskqueuepb.TaskQueue { if ms.IsStickyTaskQueueSet() { return &taskqueuepb.TaskQueue{ - Name: ms.executionInfo.StickyTaskQueue, - Kind: enumspb.TASK_QUEUE_KIND_STICKY, + Name: ms.executionInfo.StickyTaskQueue, + Kind: enumspb.TASK_QUEUE_KIND_STICKY, + NormalName: ms.executionInfo.TaskQueue, } } return &taskqueuepb.TaskQueue{ @@ -707,8 +708,9 @@ func (ms *MutableStateImpl) IsStickyTaskQueueSet() bool { func (ms *MutableStateImpl) TaskQueueScheduleToStartTimeout(name string) (*taskqueuepb.TaskQueue, *time.Duration) { if ms.executionInfo.TaskQueue != name { return &taskqueuepb.TaskQueue{ - Name: ms.executionInfo.StickyTaskQueue, - Kind: enumspb.TASK_QUEUE_KIND_STICKY, + Name: ms.executionInfo.StickyTaskQueue, + Kind: enumspb.TASK_QUEUE_KIND_STICKY, + NormalName: ms.executionInfo.TaskQueue, }, ms.executionInfo.StickyScheduleToStartTimeout } return &taskqueuepb.TaskQueue{ diff --git a/service/matching/db.go b/service/matching/db.go index 84d11e64aac..da640f16c6b 100644 --- a/service/matching/db.go +++ b/service/matching/db.go @@ -297,6 +297,12 @@ func (db *taskQueueDB) CompleteTasksLessThan( return n, err } +// Returns true if we are storing user data in the db. We need to be the root partition, +// workflow type, unversioned, and also a normal queue. +func (db *taskQueueDB) DbStoresUserData() bool { + return db.taskQueue.OwnsUserData() && db.taskQueueKind == enumspb.TASK_QUEUE_KIND_NORMAL +} + // GetUserData returns the versioning data for this task queue. Do not mutate the returned pointer, as doing so // will cause cache inconsistency. 
func (db *taskQueueDB) GetUserData( @@ -320,7 +326,7 @@ func (db *taskQueueDB) getUserDataLocked( ctx context.Context, ) (*persistencespb.VersionedTaskQueueUserData, chan struct{}, error) { if db.userData == nil { - if !db.taskQueue.OwnsUserData() { + if !db.DbStoresUserData() { return nil, db.userDataChanged, nil } @@ -350,7 +356,7 @@ func (db *taskQueueDB) getUserDataLocked( // // On success returns a pointer to the updated data, which must *not* be mutated. func (db *taskQueueDB) UpdateUserData(ctx context.Context, updateFn func(*persistencespb.TaskQueueUserData) (*persistencespb.TaskQueueUserData, error), taskQueueLimitPerBuildId int) (*persistencespb.VersionedTaskQueueUserData, error) { - if !db.taskQueue.OwnsUserData() { + if !db.DbStoresUserData() { return nil, errUserDataNoMutateNonRoot } db.Lock() diff --git a/service/matching/matchingEngine.go b/service/matching/matchingEngine.go index 5f672ad38b4..f4a72869797 100644 --- a/service/matching/matchingEngine.go +++ b/service/matching/matchingEngine.go @@ -90,7 +90,7 @@ type ( taskQueueCounterKey struct { namespaceID namespace.ID taskType enumspb.TaskQueueType - queueType enumspb.TaskQueueKind + kind enumspb.TaskQueueKind } pollMetadata struct { @@ -111,9 +111,9 @@ type ( tokenSerializer common.TaskTokenSerializer logger log.Logger metricsHandler metrics.Handler - taskQueuesLock sync.RWMutex // locks mutation of taskQueues - taskQueues map[taskQueueID]taskQueueManager // Convert to LRU cache - taskQueueCount map[taskQueueCounterKey]int // per-namespace task queue counter + taskQueuesLock sync.RWMutex // locks mutation of taskQueues + taskQueues map[taskQueueID]taskQueueManager + taskQueueCount map[taskQueueCounterKey]int // per-namespace task queue counter config *Config lockableQueryTaskMap lockableQueryTaskMap namespaceRegistry namespace.Registry @@ -231,7 +231,20 @@ func (e *matchingEngineImpl) String() string { // Returns taskQueueManager for a task queue. If not already cached, and create is true, tries // to get new range from DB and create one. This blocks (up to the context deadline) for the // task queue to be initialized. -func (e *matchingEngineImpl) getTaskQueueManager(ctx context.Context, taskQueue *taskQueueID, taskQueueKind enumspb.TaskQueueKind, create bool) (taskQueueManager, error) { +// +// Note that stickyInfo is not used as part of the task queue identity. That means that if +// getTaskQueueManager is called twice with the same taskQueue but different stickyInfo, the +// properties of the taskQueueManager will depend on which call came first. In general we can +// rely on kind being the same for all calls now, but normalName was a later addition to the +// protocol and is not always set consistently. normalName is only required when using +// versioning, and SDKs that support versioning will always set it. The current server version +// will also set it when adding tasks from history. So that particular inconsistency is okay. 
+func (e *matchingEngineImpl) getTaskQueueManager( + ctx context.Context, + taskQueue *taskQueueID, + stickyInfo stickyInfo, + create bool, +) (taskQueueManager, error) { e.taskQueuesLock.RLock() tqm, ok := e.taskQueues[*taskQueue] e.taskQueuesLock.RUnlock() @@ -245,14 +258,18 @@ func (e *matchingEngineImpl) getTaskQueueManager(ctx context.Context, taskQueue e.taskQueuesLock.Lock() if tqm, ok = e.taskQueues[*taskQueue]; !ok { var err error - tqm, err = newTaskQueueManager(e, taskQueue, taskQueueKind, e.config, e.clusterMeta) + tqm, err = newTaskQueueManager(e, taskQueue, stickyInfo, e.config, e.clusterMeta) if err != nil { e.taskQueuesLock.Unlock() return nil, err } tqm.Start() e.taskQueues[*taskQueue] = tqm - countKey := taskQueueCounterKey{namespaceID: taskQueue.namespaceID, taskType: taskQueue.taskType, queueType: taskQueueKind} + countKey := taskQueueCounterKey{ + namespaceID: taskQueue.namespaceID, + taskType: taskQueue.taskType, + kind: stickyInfo.kind, + } e.taskQueueCount[countKey]++ taskQueueCount := e.taskQueueCount[countKey] e.updateTaskQueueGauge(countKey, taskQueueCount) @@ -281,7 +298,7 @@ func (e *matchingEngineImpl) AddWorkflowTask( ) (bool, error) { namespaceID := namespace.ID(addRequest.GetNamespaceId()) taskQueueName := addRequest.TaskQueue.GetName() - taskQueueKind := addRequest.TaskQueue.GetKind() + stickyInfo := stickyInfoFromTaskQueue(addRequest.TaskQueue) origTaskQueue, err := newTaskQueueID(namespaceID, taskQueueName, enumspb.TASK_QUEUE_TYPE_WORKFLOW) if err != nil { @@ -291,14 +308,14 @@ func (e *matchingEngineImpl) AddWorkflowTask( // We don't need the userDataChanged channel here because: // - if we sync match or sticky worker unavailable, we're done // - if we spool to db, we'll re-resolve when it comes out of the db - taskQueue, _, err := e.redirectToVersionedQueueForAdd(ctx, origTaskQueue, addRequest.VersionDirective, taskQueueKind) + taskQueue, _, err := e.redirectToVersionedQueueForAdd(ctx, origTaskQueue, addRequest.VersionDirective, stickyInfo) if err != nil { return false, err } - sticky := taskQueueKind == enumspb.TASK_QUEUE_KIND_STICKY + sticky := stickyInfo.kind == enumspb.TASK_QUEUE_KIND_STICKY // do not load sticky task queue if it is not already loaded, which means it has no poller. 
- tqm, err := e.getTaskQueueManager(ctx, taskQueue, taskQueueKind, !sticky) + tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, !sticky) if err != nil { return false, err } else if sticky && (tqm == nil || !tqm.HasPollerAfter(time.Now().Add(-stickyPollerUnavailableWindow))) { @@ -341,7 +358,7 @@ func (e *matchingEngineImpl) AddActivityTask( namespaceID := namespace.ID(addRequest.GetNamespaceId()) runID := addRequest.Execution.GetRunId() taskQueueName := addRequest.TaskQueue.GetName() - taskQueueKind := addRequest.TaskQueue.GetKind() + stickyInfo := stickyInfoFromTaskQueue(addRequest.TaskQueue) origTaskQueue, err := newTaskQueueID(namespaceID, taskQueueName, enumspb.TASK_QUEUE_TYPE_ACTIVITY) if err != nil { @@ -351,12 +368,12 @@ func (e *matchingEngineImpl) AddActivityTask( // We don't need the userDataChanged channel here because: // - if we sync match, we're done // - if we spool to db, we'll re-resolve when it comes out of the db - taskQueue, _, err := e.redirectToVersionedQueueForAdd(ctx, origTaskQueue, addRequest.VersionDirective, taskQueueKind) + taskQueue, _, err := e.redirectToVersionedQueueForAdd(ctx, origTaskQueue, addRequest.VersionDirective, stickyInfo) if err != nil { return false, err } - tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, taskQueueKind, true) + tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true) if err != nil { return false, err } @@ -392,21 +409,23 @@ func (e *matchingEngineImpl) DispatchSpooledTask( ctx context.Context, task *internalTask, origTaskQueue *taskQueueID, - kind enumspb.TaskQueueKind, + stickyInfo stickyInfo, ) error { + taskInfo := task.event.GetData() // This task came from taskReader so task.event is always set here. - directive := task.event.GetData().GetVersionDirective() + directive := taskInfo.GetVersionDirective() // If this came from a versioned queue, ignore the version and re-resolve, in case we're // going to the default and the default changed. unversionedOrigTaskQueue := newTaskQueueIDWithVersionSet(origTaskQueue, "") // Redirect and re-resolve if we're blocked in matcher and user data changes. for { - taskQueue, userDataChanged, err := e.redirectToVersionedQueueForAdd(ctx, unversionedOrigTaskQueue, directive, kind) + taskQueue, userDataChanged, err := e.redirectToVersionedQueueForAdd( + ctx, unversionedOrigTaskQueue, directive, stickyInfo) if err != nil { return err } - sticky := kind == enumspb.TASK_QUEUE_KIND_STICKY - tqm, err := e.getTaskQueueManager(ctx, taskQueue, kind, !sticky) + sticky := stickyInfo.kind == enumspb.TASK_QUEUE_KIND_STICKY + tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, !sticky) if err != nil { return err } @@ -427,6 +446,7 @@ func (e *matchingEngineImpl) PollWorkflowTaskQueue( pollerID := req.GetPollerId() request := req.PollRequest taskQueueName := request.TaskQueue.GetName() + stickyInfo := stickyInfoFromTaskQueue(request.TaskQueue) e.logger.Debug("Received PollWorkflowTaskQueue for taskQueue", tag.WorkflowTaskQueueName(taskQueueName)) pollLoop: for { @@ -442,11 +462,10 @@ pollLoop: if err != nil { return nil, err } - taskQueueKind := request.TaskQueue.GetKind() pollMetadata := &pollMetadata{ workerVersionCapabilities: request.WorkerVersionCapabilities, } - task, err := e.getTask(pollerCtx, taskQueue, taskQueueKind, pollMetadata) + task, err := e.getTask(pollerCtx, taskQueue, stickyInfo, pollMetadata) if err != nil { // TODO: Is empty poll the best reply for errPumpClosed? 
if err == ErrNoTasks || err == errPumpClosed { @@ -541,6 +560,7 @@ func (e *matchingEngineImpl) PollActivityTaskQueue( pollerID := req.GetPollerId() request := req.PollRequest taskQueueName := request.TaskQueue.GetName() + stickyInfo := stickyInfoFromTaskQueue(request.TaskQueue) e.logger.Debug("Received PollActivityTaskQueue for taskQueue", tag.Name(taskQueueName)) pollLoop: for { @@ -558,14 +578,13 @@ pollLoop: // long-poll when frontend calls CancelOutstandingPoll API pollerCtx := context.WithValue(ctx, pollerIDKey, pollerID) pollerCtx = context.WithValue(pollerCtx, identityKey, request.GetIdentity()) - taskQueueKind := request.TaskQueue.GetKind() pollMetadata := &pollMetadata{ workerVersionCapabilities: request.WorkerVersionCapabilities, } if request.TaskQueueMetadata != nil && request.TaskQueueMetadata.MaxTasksPerSecond != nil { pollMetadata.ratePerSecond = &request.TaskQueueMetadata.MaxTasksPerSecond.Value } - task, err := e.getTask(pollerCtx, taskQueue, taskQueueKind, pollMetadata) + task, err := e.getTask(pollerCtx, taskQueue, stickyInfo, pollMetadata) if err != nil { // TODO: Is empty poll the best reply for errPumpClosed? if err == ErrNoTasks || err == errPumpClosed { @@ -628,7 +647,7 @@ func (e *matchingEngineImpl) QueryWorkflow( ) (*matchingservice.QueryWorkflowResponse, error) { namespaceID := namespace.ID(queryRequest.GetNamespaceId()) taskQueueName := queryRequest.TaskQueue.GetName() - taskQueueKind := queryRequest.TaskQueue.GetKind() + stickyInfo := stickyInfoFromTaskQueue(queryRequest.TaskQueue) origTaskQueue, err := newTaskQueueID(namespaceID, taskQueueName, enumspb.TASK_QUEUE_TYPE_WORKFLOW) if err != nil { @@ -637,14 +656,14 @@ func (e *matchingEngineImpl) QueryWorkflow( // We don't need the userDataChanged channel here because we either do this sync (local or remote) // or fail with a relatively short timeout. - taskQueue, _, err := e.redirectToVersionedQueueForAdd(ctx, origTaskQueue, queryRequest.VersionDirective, taskQueueKind) + taskQueue, _, err := e.redirectToVersionedQueueForAdd(ctx, origTaskQueue, queryRequest.VersionDirective, stickyInfo) if err != nil { return nil, err } - sticky := taskQueueKind == enumspb.TASK_QUEUE_KIND_STICKY + sticky := stickyInfo.kind == enumspb.TASK_QUEUE_KIND_STICKY // do not load sticky task queue if it is not already loaded, which means it has no poller. 
- tqm, err := e.getTaskQueueManager(ctx, taskQueue, taskQueueKind, !sticky) + tqm, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, !sticky) if err != nil { return nil, err } else if sticky && (tqm == nil || !tqm.HasPollerAfter(time.Now().Add(-stickyPollerUnavailableWindow))) { @@ -714,15 +733,15 @@ func (e *matchingEngineImpl) CancelOutstandingPoll( namespaceID := namespace.ID(request.GetNamespaceId()) taskQueueType := request.GetTaskQueueType() taskQueueName := request.TaskQueue.GetName() + stickyInfo := stickyInfoFromTaskQueue(request.TaskQueue) pollerID := request.GetPollerId() taskQueue, err := newTaskQueueID(namespaceID, taskQueueName, taskQueueType) if err != nil { return err } - taskQueueKind := request.TaskQueue.GetKind() - tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, taskQueueKind, true) - if err != nil { + tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, false) + if err != nil || tlMgr == nil { return err } @@ -737,12 +756,12 @@ func (e *matchingEngineImpl) DescribeTaskQueue( namespaceID := namespace.ID(request.GetNamespaceId()) taskQueueType := request.DescRequest.GetTaskQueueType() taskQueueName := request.DescRequest.TaskQueue.GetName() + stickyInfo := stickyInfoFromTaskQueue(request.DescRequest.TaskQueue) taskQueue, err := newTaskQueueID(namespaceID, taskQueueName, taskQueueType) if err != nil { return nil, err } - taskQueueKind := request.DescRequest.TaskQueue.GetKind() - tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, taskQueueKind, true) + tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true) if err != nil { return nil, err } @@ -806,7 +825,7 @@ func (e *matchingEngineImpl) UpdateWorkerBuildIdCompatibility( if err != nil { return nil, err } - tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, enumspb.TASK_QUEUE_KIND_NORMAL, true) + tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, normalStickyInfo, true) if err != nil { return nil, err } @@ -853,7 +872,7 @@ func (e *matchingEngineImpl) GetWorkerBuildIdCompatibility( if err != nil { return nil, err } - tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, enumspb.TASK_QUEUE_KIND_NORMAL, true) + tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, normalStickyInfo, true) if err != nil { if _, ok := err.(*serviceerror.NotFound); ok { return &matchingservice.GetWorkerBuildIdCompatibilityResponse{}, nil @@ -881,7 +900,7 @@ func (e *matchingEngineImpl) GetTaskQueueUserData( if err != nil { return nil, err } - tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, enumspb.TASK_QUEUE_KIND_NORMAL, true) + tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, normalStickyInfo, true) if err != nil { return nil, err } @@ -937,7 +956,7 @@ func (e *matchingEngineImpl) ApplyTaskQueueUserDataReplicationEvent( if err != nil { return nil, err } - tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, enumspb.TASK_QUEUE_KIND_NORMAL, true) + tqMgr, err := e.getTaskQueueManager(ctx, taskQueue, normalStickyInfo, true) if err != nil { return nil, err } @@ -977,8 +996,7 @@ func (e *matchingEngineImpl) ForceUnloadTaskQueue( if err != nil { return nil, err } - // kind is only used if we want to create a new tqm - tqm, err := e.getTaskQueueManager(ctx, taskQueue, enumspb.TASK_QUEUE_KIND_UNSPECIFIED, false) + tqm, err := e.getTaskQueueManager(ctx, taskQueue, normalStickyInfo, false) if err != nil { return nil, err } @@ -1072,19 +1090,19 @@ func (e *matchingEngineImpl) getAllPartitions( func (e *matchingEngineImpl) getTask( ctx context.Context, origTaskQueue *taskQueueID, - taskQueueKind 
enumspb.TaskQueueKind, + stickyInfo stickyInfo, pollMetadata *pollMetadata, ) (*internalTask, error) { taskQueue, err := e.redirectToVersionedQueueForPoll( ctx, origTaskQueue, pollMetadata.workerVersionCapabilities, - taskQueueKind, + stickyInfo, ) if err != nil { return nil, err } - tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, taskQueueKind, true) + tlMgr, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true) if err != nil { return nil, err } @@ -1100,7 +1118,7 @@ func (e *matchingEngineImpl) unloadTaskQueue(unloadTQM taskQueueManager) { return } delete(e.taskQueues, *queueID) - countKey := taskQueueCounterKey{namespaceID: queueID.namespaceID, taskType: queueID.taskType, queueType: foundTQM.TaskQueueKind()} + countKey := taskQueueCounterKey{namespaceID: queueID.namespaceID, taskType: queueID.taskType, kind: foundTQM.TaskQueueKind()} e.taskQueueCount[countKey]-- taskQueueCount := e.taskQueueCount[countKey] e.taskQueuesLock.Unlock() @@ -1120,7 +1138,7 @@ func (e *matchingEngineImpl) updateTaskQueueGauge(countKey taskQueueCounterKey, float64(taskQueueCount), metrics.NamespaceTag(namespace.String()), metrics.TaskTypeTag(countKey.taskType.String()), - metrics.QueueTypeTag(countKey.queueType.String()), + metrics.QueueTypeTag(countKey.kind.String()), ) } @@ -1282,22 +1300,15 @@ func (e *matchingEngineImpl) redirectToVersionedQueueForPoll( ctx context.Context, taskQueue *taskQueueID, workerVersionCapabilities *commonpb.WorkerVersionCapabilities, - kind enumspb.TaskQueueKind, + stickyInfo stickyInfo, ) (*taskQueueID, error) { - // Since sticky queues are pinned to a particular worker, we don't need to redirect - if kind == enumspb.TASK_QUEUE_KIND_STICKY { - // TODO: we may need to kick the workflow off of the sticky queue here - // (e.g. serviceerrors.StickyWorkerUnavailable) if there's a newer build id - return taskQueue, nil - } - if !workerVersionCapabilities.GetUseVersioning() { // Either this task queue is versioned, or there are still some workflows running on // the "unversioned" set. return taskQueue, nil } - unversionedTQM, err := e.getTaskQueueManager(ctx, taskQueue, kind, true) + unversionedTQM, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true) if err != nil { return nil, err } @@ -1309,6 +1320,14 @@ func (e *matchingEngineImpl) redirectToVersionedQueueForPoll( return nil, err } data := userData.GetData().GetVersioningData() + + if stickyInfo.kind == enumspb.TASK_QUEUE_KIND_STICKY { + // In the sticky case we don't redirect, but we may kick off this worker if there's a + // newer one. + err := checkVersionForStickyPoll(data, workerVersionCapabilities) + return taskQueue, err + } + versionSet, err := lookupVersionSetForPoll(data, workerVersionCapabilities) if err != nil { return nil, err @@ -1320,13 +1339,8 @@ func (e *matchingEngineImpl) redirectToVersionedQueueForAdd( ctx context.Context, taskQueue *taskQueueID, directive *taskqueuespb.TaskVersionDirective, - kind enumspb.TaskQueueKind, + stickyInfo stickyInfo, ) (*taskQueueID, chan struct{}, error) { - // sticky queues are unversioned - if kind == enumspb.TASK_QUEUE_KIND_STICKY { - return taskQueue, nil, nil - } - var buildId string switch dir := directive.GetValue().(type) { case *taskqueuespb.TaskVersionDirective_UseDefault: @@ -1339,7 +1353,7 @@ func (e *matchingEngineImpl) redirectToVersionedQueueForAdd( } // Have to look up versioning data. 
- unversionedTQM, err := e.getTaskQueueManager(ctx, taskQueue, kind, true) + unversionedTQM, err := e.getTaskQueueManager(ctx, taskQueue, stickyInfo, true) if err != nil { return nil, nil, err } @@ -1348,6 +1362,14 @@ func (e *matchingEngineImpl) redirectToVersionedQueueForAdd( return nil, nil, err } data := userData.GetData().GetVersioningData() + + if stickyInfo.kind == enumspb.TASK_QUEUE_KIND_STICKY { + // In the sticky case we don't redirect, but we may kick off this worker if there's a + // newer one. + err := checkVersionForStickyAdd(data, buildId) + return taskQueue, userDataChanged, err + } + versionSet, err := lookupVersionSetForAdd(data, buildId) if err == errEmptyVersioningData { // default was requested for an unversioned queue diff --git a/service/matching/matchingEngine_test.go b/service/matching/matchingEngine_test.go index 475db90c29d..c8753a07cd8 100644 --- a/service/matching/matchingEngine_test.go +++ b/service/matching/matchingEngine_test.go @@ -280,14 +280,14 @@ func (s *matchingEngineSuite) TestOnlyUnloadMatchingInstance() { tqm, err := s.matchingEngine.getTaskQueueManager( context.Background(), queueID, - enumspb.TASK_QUEUE_KIND_NORMAL, + normalStickyInfo, true) s.Require().NoError(err) tqm2, err := newTaskQueueManager( s.matchingEngine, queueID, // same queueID as above - enumspb.TASK_QUEUE_KIND_NORMAL, + normalStickyInfo, s.matchingEngine.config, s.matchingEngine.clusterMeta, ) @@ -297,7 +297,7 @@ func (s *matchingEngineSuite) TestOnlyUnloadMatchingInstance() { s.matchingEngine.unloadTaskQueue(tqm2) got, err := s.matchingEngine.getTaskQueueManager( - context.Background(), queueID, enumspb.TASK_QUEUE_KIND_NORMAL, true) + context.Background(), queueID, normalStickyInfo, true) s.Require().NoError(err) s.Require().Same(tqm, got, "Unload call with non-matching taskQueueManager should not cause unload") @@ -306,7 +306,7 @@ func (s *matchingEngineSuite) TestOnlyUnloadMatchingInstance() { s.matchingEngine.unloadTaskQueue(tqm) got, err = s.matchingEngine.getTaskQueueManager( - context.Background(), queueID, enumspb.TASK_QUEUE_KIND_NORMAL, true) + context.Background(), queueID, normalStickyInfo, true) s.Require().NoError(err) s.Require().NotSame(tqm, got, "Unload call with matching incarnation should have caused unload") @@ -630,8 +630,7 @@ func (s *matchingEngineSuite) TestTaskWriterShutdown() { execution := &commonpb.WorkflowExecution{RunId: runID, WorkflowId: workflowID} tlID := newTestTaskQueueID(namespaceID, tl, enumspb.TASK_QUEUE_TYPE_ACTIVITY) - tlKind := enumspb.TASK_QUEUE_KIND_NORMAL - tlm, err := s.matchingEngine.getTaskQueueManager(context.Background(), tlID, tlKind, true) + tlm, err := s.matchingEngine.getTaskQueueManager(context.Background(), tlID, normalStickyInfo, true) s.Nil(err) addRequest := matchingservice.AddActivityTaskRequest{ @@ -787,7 +786,6 @@ func (s *matchingEngineSuite) TestSyncMatchActivities() { namespaceID := namespace.ID(uuid.New()) tl := "makeToast" tlID := newTestTaskQueueID(namespaceID, tl, enumspb.TASK_QUEUE_TYPE_ACTIVITY) - tlKind := enumspb.TASK_QUEUE_KIND_NORMAL s.matchingEngine.config.RangeSize = rangeSize // override to low number for the test // So we can get snapshots scope := tally.NewTestScope("test", nil) @@ -795,7 +793,7 @@ func (s *matchingEngineSuite) TestSyncMatchActivities() { var err error s.taskManager.getTaskQueueManager(tlID).rangeID = initialRangeID - mgr, err := newTaskQueueManager(s.matchingEngine, tlID, tlKind, s.matchingEngine.config, s.matchingEngine.clusterMeta) + mgr, err := newTaskQueueManager(s.matchingEngine, 
tlID, normalStickyInfo, s.matchingEngine.config, s.matchingEngine.clusterMeta) s.NoError(err) mgrImpl, ok := mgr.(*taskQueueManagerImpl) @@ -1009,12 +1007,11 @@ func (s *matchingEngineSuite) concurrentPublishConsumeActivities( namespaceID := namespace.ID(uuid.New()) tl := "makeToast" tlID := newTestTaskQueueID(namespaceID, tl, enumspb.TASK_QUEUE_TYPE_ACTIVITY) - tlKind := enumspb.TASK_QUEUE_KIND_NORMAL s.matchingEngine.config.RangeSize = rangeSize // override to low number for the test s.taskManager.getTaskQueueManager(tlID).rangeID = initialRangeID var err error - mgr, err := newTaskQueueManager(s.matchingEngine, tlID, tlKind, s.matchingEngine.config, s.matchingEngine.clusterMeta) + mgr, err := newTaskQueueManager(s.matchingEngine, tlID, normalStickyInfo, s.matchingEngine.config, s.matchingEngine.clusterMeta) s.NoError(err) mgrImpl := mgr.(*taskQueueManagerImpl) @@ -1619,7 +1616,6 @@ func (s *matchingEngineSuite) TestAddTaskAfterStartFailure() { namespaceID := namespace.ID(uuid.New()) tl := "makeToast" tlID := newTestTaskQueueID(namespaceID, tl, enumspb.TASK_QUEUE_TYPE_ACTIVITY) - tlKind := enumspb.TASK_QUEUE_KIND_NORMAL taskQueue := &taskqueuepb.TaskQueue{ Name: tl, @@ -1639,12 +1635,12 @@ func (s *matchingEngineSuite) TestAddTaskAfterStartFailure() { s.NoError(err) s.EqualValues(1, s.taskManager.getTaskCount(tlID)) - ctx, err := s.matchingEngine.getTask(context.Background(), tlID, tlKind, &pollMetadata{}) + ctx, err := s.matchingEngine.getTask(context.Background(), tlID, normalStickyInfo, &pollMetadata{}) s.NoError(err) ctx.finish(errors.New("test error")) s.EqualValues(1, s.taskManager.getTaskCount(tlID)) - ctx2, err := s.matchingEngine.getTask(context.Background(), tlID, tlKind, &pollMetadata{}) + ctx2, err := s.matchingEngine.getTask(context.Background(), tlID, normalStickyInfo, &pollMetadata{}) s.NoError(err) s.NotEqual(ctx.event.GetTaskId(), ctx2.event.GetTaskId()) @@ -1755,13 +1751,12 @@ func (s *matchingEngineSuite) TestTaskQueueManagerGetTaskBatch_ReadBatchDone() { namespaceID := namespace.ID(uuid.New()) tl := "makeToast" tlID := newTestTaskQueueID(namespaceID, tl, enumspb.TASK_QUEUE_TYPE_ACTIVITY) - tlNormal := enumspb.TASK_QUEUE_KIND_NORMAL const rangeSize = 10 const maxReadLevel = int64(120) config := defaultTestConfig() config.RangeSize = rangeSize - tlMgr0, err := newTaskQueueManager(s.matchingEngine, tlID, tlNormal, config, s.matchingEngine.clusterMeta) + tlMgr0, err := newTaskQueueManager(s.matchingEngine, tlID, normalStickyInfo, config, s.matchingEngine.clusterMeta) s.NoError(err) tlMgr, ok := tlMgr0.(*taskQueueManagerImpl) @@ -1793,13 +1788,12 @@ func (s *matchingEngineSuite) TestTaskQueueManager_CyclingBehavior() { namespaceID := namespace.ID(uuid.New()) tl := "makeToast" tlID := newTestTaskQueueID(namespaceID, tl, enumspb.TASK_QUEUE_TYPE_ACTIVITY) - tlNormal := enumspb.TASK_QUEUE_KIND_NORMAL config := defaultTestConfig() for i := 0; i < 4; i++ { prevGetTasksCount := s.taskManager.getGetTasksCount(tlID) - tlMgr, err := newTaskQueueManager(s.matchingEngine, tlID, tlNormal, config, s.matchingEngine.clusterMeta) + tlMgr, err := newTaskQueueManager(s.matchingEngine, tlID, normalStickyInfo, config, s.matchingEngine.clusterMeta) s.NoError(err) tlMgr.Start() diff --git a/service/matching/taskQueueManager.go b/service/matching/taskQueueManager.go index 72f878a67c6..f548e4d8c38 100644 --- a/service/matching/taskQueueManager.go +++ b/service/matching/taskQueueManager.go @@ -100,6 +100,11 @@ type ( forwardedFrom string } + stickyInfo struct { + kind enumspb.TaskQueueKind // 
sticky taskQueue has different process in persistence + normalName string // if kind is sticky, name of normal queue + } + UserDataUpdateOptions struct { Replicate bool TaskQueueLimitPerBuildId int @@ -141,10 +146,10 @@ type ( // Single task queue in memory state taskQueueManagerImpl struct { - status int32 - engine *matchingEngineImpl - taskQueueID *taskQueueID - taskQueueKind enumspb.TaskQueueKind // sticky taskQueue has different process in persistence + status int32 + engine *matchingEngineImpl + taskQueueID *taskQueueID + stickyInfo config *taskQueueConfig db *taskQueueDB taskWriter *taskWriter @@ -179,7 +184,12 @@ type ( var _ taskQueueManager = (*taskQueueManagerImpl)(nil) -var errRemoteSyncMatchFailed = serviceerror.NewCanceled("remote sync match failed") +var ( + errRemoteSyncMatchFailed = serviceerror.NewCanceled("remote sync match failed") + errMissingNormalQueueName = errors.New("missing normal queue name") + + normalStickyInfo = stickyInfo{kind: enumspb.TASK_QUEUE_KIND_NORMAL} +) func withIDBlockAllocator(ibl idBlockAllocator) taskQueueManagerOpt { return func(tqm *taskQueueManagerImpl) { @@ -187,10 +197,17 @@ func withIDBlockAllocator(ibl idBlockAllocator) taskQueueManagerOpt { } } +func stickyInfoFromTaskQueue(tq *taskqueuepb.TaskQueue) stickyInfo { + return stickyInfo{ + kind: tq.GetKind(), + normalName: tq.GetNormalName(), + } +} + func newTaskQueueManager( e *matchingEngineImpl, taskQueue *taskQueueID, - taskQueueKind enumspb.TaskQueueKind, + stickyInfo stickyInfo, config *Config, clusterMeta cluster.Metadata, opts ...taskQueueManagerOpt, @@ -203,7 +220,7 @@ func newTaskQueueManager( taskQueueConfig := newTaskQueueConfig(taskQueue, config, nsName) - db := newTaskQueueDB(e.taskManager, e.matchingClient, taskQueue.namespaceID, taskQueue, taskQueueKind, e.logger) + db := newTaskQueueDB(e.taskManager, e.matchingClient, taskQueue.namespaceID, taskQueue, stickyInfo.kind, e.logger) logger := log.With(e.logger, tag.WorkflowTaskQueueName(taskQueue.FullName()), tag.WorkflowTaskQueueType(taskQueue.taskType), @@ -212,7 +229,7 @@ func newTaskQueueManager( e.metricsHandler.WithTags(metrics.OperationTag(metrics.MatchingTaskQueueMgrScope), metrics.TaskQueueTypeTag(taskQueue.taskType)), nsName.String(), taskQueue.FullName(), - taskQueueKind, + stickyInfo.kind, ) tlMgr := &taskQueueManagerImpl{ status: common.DaemonStatusInitialized, @@ -221,7 +238,7 @@ func newTaskQueueManager( matchingClient: e.matchingClient, metricsHandler: e.metricsHandler, taskQueueID: taskQueue, - taskQueueKind: taskQueueKind, + stickyInfo: stickyInfo, logger: logger, db: db, taskAckManager: newAckManager(e.logger), @@ -245,11 +262,11 @@ func newTaskQueueManager( tlMgr.taskReader = newTaskReader(tlMgr) var fwdr *Forwarder - if tlMgr.isFowardingAllowed(taskQueue, taskQueueKind) { + if tlMgr.isFowardingAllowed(taskQueue, stickyInfo.kind) { // Forward without version set, the target will resolve the correct version set from // the build id itself. 
TODO: check if we still need this here after tqm refactoring forwardTaskQueue := newTaskQueueIDWithVersionSet(taskQueue, "") - fwdr = newForwarder(&taskQueueConfig.forwarderConfig, forwardTaskQueue, taskQueueKind, e.matchingClient) + fwdr = newForwarder(&taskQueueConfig.forwarderConfig, forwardTaskQueue, stickyInfo.kind, e.matchingClient) } tlMgr.matcher = newTaskMatcher(taskQueueConfig, fwdr, tlMgr.taggedMetricsHandler) for _, opt := range opts { @@ -291,7 +308,9 @@ func (c *taskQueueManagerImpl) Start() { c.liveness.Start() c.taskWriter.Start() c.taskReader.Start() - c.goroGroup.Go(c.fetchUserDataLoop) + if c.shouldFetchUserData() { + c.goroGroup.Go(c.fetchUserDataLoop) + } c.logger.Info("", tag.LifeCycleStarted) c.taggedMetricsHandler.Counter(metrics.TaskQueueStartedCounter.GetMetricName()).Record(1) } @@ -337,6 +356,15 @@ func (c *taskQueueManagerImpl) isVersioned() bool { return c.taskQueueID.VersionSet() != "" } +// shouldFetchUserData consolidates the logic for when to fetch user data from another task +// queue or (maybe) read it from the db. We set the userDataInitialFetch future from two +// places, so they need to agree on which one should set it. +func (c *taskQueueManagerImpl) shouldFetchUserData() bool { + // 1. If the db stores it, then we definitely should not be fetching. + // 2. Additionally, we should not fetch for "versioned" tqms. + return !c.db.DbStoresUserData() && !c.isVersioned() +} + func (c *taskQueueManagerImpl) WaitUntilInitialized(ctx context.Context) error { _, err := c.initializedError.Get(ctx) if err != nil { @@ -703,7 +731,7 @@ func (c *taskQueueManagerImpl) QueueID() *taskQueueID { } func (c *taskQueueManagerImpl) TaskQueueKind() enumspb.TaskQueueKind { - return c.taskQueueKind + return c.kind } func (c *taskQueueManagerImpl) callerInfoContext(ctx context.Context) context.Context { @@ -717,6 +745,15 @@ func (c *taskQueueManagerImpl) newIOContext() (context.Context, context.CancelFu } func (c *taskQueueManagerImpl) userDataFetchSource() (string, error) { + if c.kind == enumspb.TASK_QUEUE_KIND_STICKY { + // Sticky queues get data from their corresponding normal queue + if c.normalName == "" { + // Older SDKs don't send the normal name. That's okay, they just can't use versioning. 
+ return "", errMissingNormalQueueName + } + return c.normalName, nil + } + degree := c.config.ForwarderMaxChildrenPerNode() parent, err := c.taskQueueID.Parent(degree) if err == tqname.ErrNoParent { @@ -732,13 +769,12 @@ func (c *taskQueueManagerImpl) userDataFetchSource() (string, error) { func (c *taskQueueManagerImpl) fetchUserDataLoop(ctx context.Context) error { ctx = c.callerInfoContext(ctx) - // root workflow partition reads data from db; versioned tqm has no user data - if c.taskQueueID.OwnsUserData() || c.isVersioned() { - return nil - } - fetchSource, err := c.userDataFetchSource() if err != nil { + if err == errMissingNormalQueueName { + // pretend we have no user data + c.userDataInitialFetch.Set(struct{}{}, nil) + } return err } diff --git a/service/matching/taskQueueManager_test.go b/service/matching/taskQueueManager_test.go index e6822955f1b..3b404199110 100644 --- a/service/matching/taskQueueManager_test.go +++ b/service/matching/taskQueueManager_test.go @@ -33,6 +33,7 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" commonpb "go.temporal.io/api/common/v1" @@ -350,8 +351,7 @@ func createTestTaskQueueManagerWithConfig( mockNamespaceCache.EXPECT().GetNamespaceName(gomock.Any()).Return(namespace.Name("ns-name"), nil).AnyTimes() cmeta := cluster.NewMetadataForTest(cluster.NewTestClusterMetadataConfig(false, true)) me := newMatchingEngine(testOpts.config, tm, nil, logger, mockNamespaceCache, testOpts.matchingClientMock) - tlKind := enumspb.TASK_QUEUE_KIND_NORMAL - tlMgr, err := newTaskQueueManager(me, testOpts.tqId, tlKind, testOpts.config, cmeta, opts...) + tlMgr, err := newTaskQueueManager(me, testOpts.tqId, normalStickyInfo, testOpts.config, cmeta, opts...) if err != nil { return nil, err } @@ -769,6 +769,63 @@ func TestTQMFetchesUserDataActivityToWorkflow(t *testing.T) { tq.Stop() } +func TestTQMFetchesUserDataStickyToNormal(t *testing.T) { + controller := gomock.NewController(t) + defer controller.Finish() + ctx := context.Background() + tqCfg := defaultTqmTestOpts(controller) + + normalName := "normal-queue" + stickyName := uuid.New() + + tqId, err := newTaskQueueIDWithPartition(defaultNamespaceId, stickyName, enumspb.TASK_QUEUE_TYPE_WORKFLOW, 0) + require.NoError(t, err) + tqCfg.tqId = tqId + + data1 := &persistencespb.VersionedTaskQueueUserData{ + Version: 1, + Data: mkUserData(1), + } + + tqCfg.matchingClientMock.EXPECT().GetTaskQueueUserData( + gomock.Any(), + &matchingservice.GetTaskQueueUserDataRequest{ + NamespaceId: defaultNamespaceId.String(), + TaskQueue: normalName, + TaskQueueType: enumspb.TASK_QUEUE_TYPE_WORKFLOW, + LastKnownUserDataVersion: 0, + WaitNewData: false, + }). 
+ Return(&matchingservice.GetTaskQueueUserDataResponse{ + TaskQueueHasUserData: true, + UserData: data1, + }, nil) + + // have to create manually to get sticky + logger := log.NewTestLogger() + tm := newTestTaskManager(logger) + mockNamespaceCache := namespace.NewMockRegistry(controller) + mockNamespaceCache.EXPECT().GetNamespaceByID(gomock.Any()).Return(&namespace.Namespace{}, nil).AnyTimes() + mockNamespaceCache.EXPECT().GetNamespaceName(gomock.Any()).Return(namespace.Name("ns-name"), nil).AnyTimes() + me := newMatchingEngine(tqCfg.config, tm, nil, logger, mockNamespaceCache, tqCfg.matchingClientMock) + cmeta := cluster.NewMetadataForTest(cluster.NewTestClusterMetadataConfig(false, true)) + stickyInfo := stickyInfo{ + kind: enumspb.TASK_QUEUE_KIND_STICKY, + normalName: normalName, + } + tlMgr, err := newTaskQueueManager(me, tqCfg.tqId, stickyInfo, tqCfg.config, cmeta) + require.NoError(t, err) + tq := tlMgr.(*taskQueueManagerImpl) + + tq.config.GetUserDataMinWaitTime = 10 * time.Second // wait on success + tq.Start() + require.NoError(t, tq.WaitUntilInitialized(ctx)) + userData, _, err := tq.GetUserData(ctx) + require.NoError(t, err) + require.Equal(t, data1, userData) + tq.Stop() +} + func TestUpdateOnNonRootFails(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() diff --git a/service/matching/taskReader.go b/service/matching/taskReader.go index 87dcfa06e7b..e9d95ae352d 100644 --- a/service/matching/taskReader.go +++ b/service/matching/taskReader.go @@ -131,7 +131,7 @@ dispatchLoop: // Don't try to set read level here because it may have been advanced already. break } - err := tr.tlMgr.engine.DispatchSpooledTask(ctx, task, tr.tlMgr.taskQueueID, tr.tlMgr.taskQueueKind) + err := tr.tlMgr.engine.DispatchSpooledTask(ctx, task, tr.tlMgr.taskQueueID, tr.tlMgr.stickyInfo) if err == nil { break } diff --git a/service/matching/taskWriter.go b/service/matching/taskWriter.go index c28a2650d38..91540029923 100644 --- a/service/matching/taskWriter.go +++ b/service/matching/taskWriter.go @@ -219,7 +219,7 @@ func (w *taskWriter) appendTasks( func (w *taskWriter) taskWriterLoop(ctx context.Context) error { err := w.initReadWriteState(ctx) w.tlMgr.initializedError.Set(struct{}{}, err) - if w.taskQueueID.OwnsUserData() || w.tlMgr.isVersioned() { + if !w.tlMgr.shouldFetchUserData() { w.tlMgr.userDataInitialFetch.Set(struct{}{}, err) } if err != nil { diff --git a/service/matching/version_sets.go b/service/matching/version_sets.go index 8a65992c9a6..a8436d5d340 100644 --- a/service/matching/version_sets.go +++ b/service/matching/version_sets.go @@ -36,6 +36,7 @@ import ( "go.temporal.io/api/workflowservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" hlc "go.temporal.io/server/common/clock/hybrid_logical_clock" + serviceerrors "go.temporal.io/server/common/serviceerror" ) var ( @@ -313,7 +314,7 @@ func lookupVersionSetForPoll(data *persistencespb.VersioningData, caps *commonpb // For poll, only the latest version in the compatible set can get tasks. // Find the version set that this worker is in. // Note data may be nil here, findVersion will return -1 then. - setIdx, _ := findVersion(data, caps.BuildId) + setIdx, indexInSet := findVersion(data, caps.BuildId) if setIdx < 0 { // A poller is using a build ID but we don't know about that build ID. 
This can happen // in a replication scenario if pollers are running on the passive side before the data @@ -328,13 +329,33 @@ func lookupVersionSetForPoll(data *persistencespb.VersioningData, caps *commonpb return guessedSetId, nil } set := data.VersionSets[setIdx] - latestInSet := set.BuildIds[len(set.BuildIds)-1].Id - if caps.BuildId != latestInSet { - return "", serviceerror.NewNewerBuildExists(latestInSet) + lastIndex := len(set.BuildIds) - 1 + if indexInSet != lastIndex { + return "", serviceerror.NewNewerBuildExists(set.BuildIds[lastIndex].Id) } return getSetID(set), nil } +// Requires: caps is not nil +func checkVersionForStickyPoll(data *persistencespb.VersioningData, caps *commonpb.WorkerVersionCapabilities) error { + // For poll, only the latest version in the compatible set can get tasks. + // Find the version set that this worker is in. + // Note data may be nil here, findVersion will return -1 then. + setIdx, indexInSet := findVersion(data, caps.BuildId) + if setIdx < 0 { + // A poller is using a build ID but we don't know about that build ID. See comments in + // lookupVersionSetForPoll. If we consider it the default for its set, then we should + // leave it on the sticky queue here. + return nil + } + set := data.VersionSets[setIdx] + lastIndex := len(set.BuildIds) - 1 + if indexInSet != lastIndex { + return serviceerror.NewNewerBuildExists(set.BuildIds[lastIndex].Id) + } + return nil +} + // For this function, buildId == "" means "use default" func lookupVersionSetForAdd(data *persistencespb.VersioningData, buildId string) (string, error) { var set *persistencespb.CompatibleVersionSet @@ -371,6 +392,28 @@ func lookupVersionSetForAdd(data *persistencespb.VersioningData, buildId string) return getSetID(set), nil } +// For this function, buildId == "" means "use default" +func checkVersionForStickyAdd(data *persistencespb.VersioningData, buildId string) error { + if buildId == "" { + // This shouldn't happen. + return serviceerror.NewInternal("should have a build id directive on versioned sticky queue") + } + // For add, any version in the compatible set maps to the set. + // Note data may be nil here, findVersion will return -1 then. + setIdx, indexInSet := findVersion(data, buildId) + if setIdx < 0 { + // A poller is using a build ID but we don't know about that build ID. See comments in + // lookupVersionSetForAdd. If we consider it the default for its set, then we should + // leave it on the sticky queue here. + return nil + } + // If this is not the set's default anymore, we need to kick it back to the regular queue. + if indexInSet != len(data.VersionSets[setIdx].BuildIds)-1 { + return serviceerrors.NewStickyWorkerUnavailable() + } + return nil +} + // getSetID returns an arbitrary but consistent member of the set. 
// We want Add and Poll requests for the same set to converge on a single id so we can match // them, but we don't have a single id for a set in the general case: in rare cases we may have diff --git a/tests/versioning_test.go b/tests/versioning_test.go index 6550a0219fa..eadb6bd4188 100644 --- a/tests/versioning_test.go +++ b/tests/versioning_test.go @@ -33,6 +33,7 @@ import ( "testing" "time" + "github.com/dgryski/go-farm" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -62,6 +63,9 @@ type versioningIntegSuite struct { const ( partitionTreeDegree = 3 + longPollTime = 5 * time.Second + // use > 2 pollers by default to expose more timing situations + numPollers = 4 ) func (s *versioningIntegSuite) SetupSuite() { @@ -70,6 +74,12 @@ func (s *versioningIntegSuite) SetupSuite() { dynamicconfig.FrontendEnableWorkerVersioningWorkflowAPIs: true, dynamicconfig.MatchingForwarderMaxChildrenPerNode: partitionTreeDegree, dynamicconfig.TaskQueuesPerBuildIdLimit: 3, + + // The dispatch tests below rely on being able to see the effects of changing + // versioning data relatively quickly. In general we only promise to act on new + // versioning data "soon", i.e. after a long poll interval. We can reduce the long poll + // interval so that we don't have to wait so long. + dynamicconfig.MatchingLongPollExpirationInterval: longPollTime, } s.setupSuite("testdata/integration_test_cluster.yaml") } @@ -279,15 +289,16 @@ func (s *versioningIntegSuite) dispatchNewWorkflow() { return "done!", nil } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() s.addNewDefaultBuildId(ctx, tq, "v1") s.waitForPropagation(ctx, tq, "v1") w1 := worker.New(s.sdkClient, tq, worker.Options{ - BuildID: s.prefixed("v1"), - UseBuildIDForVersioning: true, + BuildID: s.prefixed("v1"), + UseBuildIDForVersioning: true, + MaxConcurrentWorkflowTaskPollers: numPollers, }) w1.RegisterWorkflow(wf) s.NoError(w1.Start()) @@ -313,8 +324,9 @@ func (s *versioningIntegSuite) dispatchNewWorkflowStartWorkerFirst() { // run worker before registering build. 
it will use guessed set id
 	w1 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v1"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v1"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w1.RegisterWorkflow(wf)
 	s.NoError(w1.Start())
@@ -323,7 +335,7 @@ func (s *versioningIntegSuite) dispatchNewWorkflowStartWorkerFirst() {
 	// wait for it to start polling
 	time.Sleep(200 * time.Millisecond)

-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	s.addNewDefaultBuildId(ctx, tq, "v1")
@@ -343,7 +355,7 @@ func (s *versioningIntegSuite) TestDispatchUnversionedRemainsUnversioned() {

 func (s *versioningIntegSuite) dispatchUnversionedRemainsUnversioned() {
 	tq := s.randomizeStr(s.T().Name())

-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	started := make(chan struct{}, 1)
@@ -376,14 +388,15 @@ func (s *versioningIntegSuite) dispatchUnversionedRemainsUnversioned() {
 	s.Equal("done!", out)
 }

-func (s *versioningIntegSuite) TestDispatchUpgradeStickyTimeout() {
+func (s *versioningIntegSuite) TestDispatchUpgradeStopOld() {
 	s.testWithMatchingBehavior(func() { s.dispatchUpgrade(true) })
 }

-func (s *versioningIntegSuite) TestDispatchUpgradeStickyUnavailable() {
+
+func (s *versioningIntegSuite) TestDispatchUpgradeWait() {
 	s.testWithMatchingBehavior(func() { s.dispatchUpgrade(false) })
 }

-func (s *versioningIntegSuite) dispatchUpgrade(letStickyWftTimeout bool) {
+func (s *versioningIntegSuite) dispatchUpgrade(stopOld bool) {
 	tq := s.randomizeStr(s.T().Name())

 	started := make(chan struct{}, 1)
@@ -394,21 +407,21 @@ func (s *versioningIntegSuite) dispatchUpgrade(letStickyWftTimeout bool) {
 		return "done!", nil
 	}

-	wf2 := func(ctx workflow.Context) (string, error) {
+	wf11 := func(ctx workflow.Context) (string, error) {
 		workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil)
-		return "done from two!", nil
+		return "done from 1.1!", nil
 	}

-	// TODO: reduce after fixing sticky
-	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	s.addNewDefaultBuildId(ctx, tq, "v1")
 	s.waitForPropagation(ctx, tq, "v1")

 	w1 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v1"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v1"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w1.Start())
@@ -418,39 +431,41 @@ func (s *versioningIntegSuite) dispatchUpgrade(letStickyWftTimeout bool) {
 	s.NoError(err)
 	s.waitForChan(ctx, started)

-	// Stop w1 to break stickiness
-	// TODO: this shouldn't be necessary, add behavior cases that disable stickiness
-	w1.Stop()
+	// now add v11 as compatible so the next workflow task runs there
+	s.addCompatibleBuildId(ctx, tq, "v11", "v1", false)
+	s.waitForPropagation(ctx, tq, "v11")
+	// add another 100ms to make sure it got to sticky queues also
+	time.Sleep(100 * time.Millisecond)

-	// two methods of breaking stickiness:
-	if letStickyWftTimeout {
-		// in this case we just start the new worker and kick the workflow immediately. the new
-		// wft will go to the sticky queue, be spooled, but eventually timeout and we'll get a
-		// new wft.
+	w11 := worker.New(s.sdkClient, tq, worker.Options{
+		BuildID:                          s.prefixed("v11"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
+	})
+	w11.RegisterWorkflowWithOptions(wf11, workflow.RegisterOptions{Name: "wf"})
+	s.NoError(w11.Start())
+	defer w11.Stop()
+
+	// Two cases:
+	if stopOld {
+		// Stop the old worker. Workflow tasks will go to the sticky queue, which will see that
+		// it's not the latest and kick them back to the normal queue, which will be dispatched
+		// to v11.
+		w1.Stop()
 	} else {
-		// in this case we sleep for more than stickyPollerUnavailableWindow. matching will
-		// return StickyWorkerUnavailable immediately after that.
-		time.Sleep(11 * time.Second)
+		// Don't stop the old worker. In this case, w1 will still have some pollers blocked on
+		// the normal queue which could pick up tasks that we want to go to v11. (We don't
+		// interrupt long polls.) To ensure those polls don't interfere, wait for them to
+		// expire.
+		time.Sleep(longPollTime)
 	}

-	// now add v2 as compatible so the next workflow task runs there
-	s.addCompatibleBuildId(ctx, tq, "v2", "v1", false)
-	s.waitForPropagation(ctx, tq, "v2")
-
-	w2 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v2"),
-		UseBuildIDForVersioning: true,
-	})
-	w2.RegisterWorkflowWithOptions(wf2, workflow.RegisterOptions{Name: "wf"})
-	s.NoError(w2.Start())
-	defer w2.Stop()
-
 	// unblock the workflow
 	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "wait", nil))

 	var out string
 	s.NoError(run.Get(ctx, &out))
-	s.Equal("done from two!", out)
+	s.Equal("done from 1.1!", out)
 }

 func (s *versioningIntegSuite) TestDispatchActivity() {
@@ -493,15 +508,16 @@ func (s *versioningIntegSuite) dispatchActivity() {
 		panic("workflow should not run on v2")
 	}

-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	s.addNewDefaultBuildId(ctx, tq, "v1")
 	s.waitForPropagation(ctx, tq, "v1")

 	w1 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v1"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v1"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"})
 	w1.RegisterActivityWithOptions(act1, activity.RegisterOptions{Name: "act"})
@@ -519,8 +535,9 @@ func (s *versioningIntegSuite) dispatchActivity() {
 	s.waitForPropagation(ctx, tq, "v2")
 	// start worker for v2
 	w2 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v2"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v2"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w2.RegisterWorkflowWithOptions(wf2, workflow.RegisterOptions{Name: "wf"})
 	w2.RegisterActivityWithOptions(act2, activity.RegisterOptions{Name: "act"})
@@ -569,15 +586,16 @@ func (s *versioningIntegSuite) dispatchChildWorkflow() {
 		panic("workflow should not run on v2")
 	}

-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	s.addNewDefaultBuildId(ctx, tq, "v1")
 	s.waitForPropagation(ctx, tq, "v1")

 	w1 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v1"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v1"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"})
 	w1.RegisterWorkflowWithOptions(child1, workflow.RegisterOptions{Name: "child"})
@@ -595,8 +613,9 @@ func (s *versioningIntegSuite) dispatchChildWorkflow() {
 	s.waitForPropagation(ctx, tq, "v2")
 	// start worker for v2
 	w2 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v2"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v2"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w2.RegisterWorkflowWithOptions(wf2, workflow.RegisterOptions{Name: "wf"})
 	w2.RegisterWorkflowWithOptions(child2, workflow.RegisterOptions{Name: "child"})
@@ -616,21 +635,19 @@ func (s *versioningIntegSuite) TestDispatchContinueAsNew() {
 }

 func (s *versioningIntegSuite) dispatchContinueAsNew() {
-	// TODO: this test will need updating after fixing stickiness, see comments below
 	tq := s.randomizeStr(s.T().Name())

-	started1 := make(chan struct{})
-	started11 := make(chan struct{})
-	started2 := make(chan struct{})
+	started1 := make(chan struct{}, 10)
+	started11 := make(chan struct{}, 20)

 	wf1 := func(ctx workflow.Context, attempt int) (string, error) {
 		started1 <- struct{}{}
 		workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil)
 		switch attempt {
 		case 0:
-			// TODO: after fixing stickiness, comment this out:
-			return "", workflow.NewContinueAsNewError(ctx, "wf", attempt+1)
+			// return "", workflow.NewContinueAsNewError(ctx, "wf", attempt+1)
 		case 1:
+			// newCtx := workflow.WithWorkflowVersioningIntent(ctx, temporal.VersioningIntentDefault) // this one should go to default
 			// return "", workflow.NewContinueAsNewError(newCtx, "wf", attempt+1)
 		case 2:
 			// return "done!", nil
@@ -642,8 +659,7 @@ func (s *versioningIntegSuite) dispatchContinueAsNew() {
 		workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil)
 		switch attempt {
 		case 0:
-			// TODO: after fixing stickiness, uncomment this:
-			// return "", workflow.NewContinueAsNewError(ctx, "wf", attempt+1)
+			return "", workflow.NewContinueAsNewError(ctx, "wf", attempt+1)
 		case 1:
 			newCtx := workflow.WithWorkflowVersioningIntent(ctx, temporal.VersioningIntentDefault) // this one should go to default
 			return "", workflow.NewContinueAsNewError(newCtx, "wf", attempt+1)
@@ -653,28 +669,23 @@ func (s *versioningIntegSuite) dispatchContinueAsNew() {
 		panic("oops")
 	}

 	wf2 := func(ctx workflow.Context, attempt int) (string, error) {
-		started2 <- struct{}{}
-		workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil)
 		switch attempt {
-		case 0:
-			// return "",workflow.NewContinueAsNewError(ctx, "wf", attempt+1)
-		case 1:
-			// return "", workflow.NewContinueAsNewError(newCtx, "wf", attempt+1)
 		case 2:
 			return "done!", nil
 		}
 		panic("oops")
 	}

-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	s.addNewDefaultBuildId(ctx, tq, "v1")
 	s.waitForPropagation(ctx, tq, "v1")

 	w1 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v1"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v1"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w1.Start())
@@ -689,35 +700,40 @@ func (s *versioningIntegSuite) dispatchContinueAsNew() {
 	s.addCompatibleBuildId(ctx, tq, "v11", "v1", false)
 	s.addNewDefaultBuildId(ctx, tq, "v2")
 	s.waitForPropagation(ctx, tq, "v2")
+	// add another 100ms to make sure it got to sticky queues also
+	time.Sleep(100 * time.Millisecond)

 	// start workers for v11 and v2
 	w11 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v11"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v11"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w11.RegisterWorkflowWithOptions(wf11, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w11.Start())
 	defer w11.Stop()

 	w2 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v2"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v2"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w2.RegisterWorkflowWithOptions(wf2, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w2.Start())
 	defer w2.Stop()

-	// unblock the workflow. it should continue on v1 then continue-as-new onto v11
-	// TODO: after fixing stickiness, it should continue on v11 and then continue-as-new onto v11.
-	// will also need to mess with channels then.
+	// wait for w1 long polls to all time out
+	time.Sleep(longPollTime)
+
+	// unblock the workflow. it should get kicked off the sticky queue and replay on v11
 	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))
+	s.waitForChan(ctx, started11)

-	// wait for it to start on v11 then unblock. it should continue on v11 then continue-as-new onto v2.
+	// then continue-as-new onto v11
 	s.waitForChan(ctx, started11)
-	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))

-	// wait for it to start on v2 and unblock. it should return.
-	s.waitForChan(ctx, started2)
+	// unblock the second run. it should continue on v11 then continue-as-new onto v2, then
+	// complete.
 	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))

 	var out string
@@ -730,19 +746,17 @@ func (s *versioningIntegSuite) TestDispatchRetry() {
 }

 func (s *versioningIntegSuite) dispatchRetry() {
-	// TODO: this test will need updating after fixing stickiness, see comments below
 	tq := s.randomizeStr(s.T().Name())

-	started1 := make(chan struct{}, 3)
-	started11 := make(chan struct{}, 3)
+	started1 := make(chan struct{}, 10)
+	started11 := make(chan struct{}, 30)

 	wf1 := func(ctx workflow.Context) (string, error) {
 		started1 <- struct{}{}
 		workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil)
 		switch workflow.GetInfo(ctx).Attempt {
 		case 1:
-			// TODO: stickiness
-			return "", errors.New("try again")
+			// return "", errors.New("try again")
 		case 2:
 			// return "", errors.New("try again")
 		case 3:
@@ -755,8 +769,7 @@
 		workflow.GetSignalChannel(ctx, "wait").Receive(ctx, nil)
 		switch workflow.GetInfo(ctx).Attempt {
 		case 1:
-			// TODO: stickiness
-			// return "", errors.New("try again")
+			return "", errors.New("try again")
 		case 2:
 			return "", errors.New("try again")
 		case 3:
@@ -768,15 +781,16 @@
 		panic("oops")
 	}

-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	s.addNewDefaultBuildId(ctx, tq, "v1")
 	s.waitForPropagation(ctx, tq, "v1")

 	w1 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v1"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v1"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w1.Start())
@@ -790,46 +804,47 @@ func (s *versioningIntegSuite) dispatchRetry() {
 	}, "wf")
 	s.NoError(err)
 	// wait for it to start on v1
-	<-started1
+	s.waitForChan(ctx, started1)

 	// now register v11 as newer compatible with v1 AND v2 as a new default
 	s.addCompatibleBuildId(ctx, tq, "v11", "v1", false)
 	s.addNewDefaultBuildId(ctx, tq, "v2")
 	s.waitForPropagation(ctx, tq, "v2")
+	// add another 100ms to make sure it got to sticky queues also
+	time.Sleep(100 * time.Millisecond)

 	// start workers for v11 and v2
 	w11 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v11"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v11"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w11.RegisterWorkflowWithOptions(wf11, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w11.Start())
 	defer w11.Stop()

 	w2 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v2"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v2"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w2.RegisterWorkflowWithOptions(wf2, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w2.Start())
 	defer w2.Stop()

-	// unblock the workflow. it should continue on v1 then retry onto v11
-	// TODO: after fixing stickiness, it should continue on v11 and then retry onto v11.
-	// will also need to mess with channels then.
-	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))
+	// wait for w1 long polls to all time out
+	time.Sleep(longPollTime)

-	// TODO: fix this hack by making polls interruptable? the problem is w1 still has a poller
-	// sitting around blocked since before we changed versioning data
-	time.Sleep(500 * time.Millisecond)
-	w1.Stop()
+	// unblock the workflow. it should replay on v11 and then retry (on v11).
+	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))
+	s.waitForChan(ctx, started11) // replay
+	s.waitForChan(ctx, started11) // attempt 2

-	// wait for it to start on v11 then unblock. it should continue on v11 then retry onto v11 again.
-	<-started11
+	// now it's blocked in attempt 2. unblock it.
 	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))

-	// wait for it to start on v11 and unblock. it should return.
-	<-started11
+	// wait for attempt 3. unblock that and it should return.
+	s.waitForChan(ctx, started11) // attempt 3
 	s.NoError(s.sdkClient.SignalWorkflow(ctx, run.GetID(), "", "wait", nil))

 	var out string
@@ -861,15 +876,16 @@ func (s *versioningIntegSuite) dispatchCron() {
 		return "ok", nil
 	}

-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 	defer cancel()

 	s.addNewDefaultBuildId(ctx, tq, "v1")
 	s.waitForPropagation(ctx, tq, "v1")

 	w1 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v1"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v1"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w1.RegisterWorkflowWithOptions(wf1, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w1.Start())
@@ -892,16 +908,18 @@ func (s *versioningIntegSuite) dispatchCron() {

 	// start workers for v11 and v2
 	w11 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v11"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v11"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w11.RegisterWorkflowWithOptions(wf11, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w11.Start())
 	defer w11.Stop()

 	w2 := worker.New(s.sdkClient, tq, worker.Options{
-		BuildID:                 s.prefixed("v2"),
-		UseBuildIDForVersioning: true,
+		BuildID:                          s.prefixed("v2"),
+		UseBuildIDForVersioning:          true,
+		MaxConcurrentWorkflowTaskPollers: numPollers,
 	})
 	w2.RegisterWorkflowWithOptions(wf2, workflow.RegisterOptions{Name: "wf"})
 	s.NoError(w2.Start())
@@ -917,7 +935,7 @@ func (s *versioningIntegSuite) dispatchCron() {

 // Add a per test prefix to avoid hitting the namespace limit of mapped task queue per build id
 func (s *versioningIntegSuite) prefixed(buildId string) string {
-	return s.T().Name() + ":" + buildId
+	return fmt.Sprintf("t%x:%s", 0xffff&farm.Hash32([]byte(s.T().Name())), buildId)
 }

 // addNewDefaultBuildId updates build id info on a task queue with a new build id in a new default set.