Skip to content

Commit

Permalink
Handle worker versioning for sticky queues (#4397)
Browse files Browse the repository at this point in the history
Note: This commit came from a feature branch and is not expected to build.
  • Loading branch information
dnr committed May 26, 2023
1 parent b717268 commit 2de41e5
Show file tree
Hide file tree
Showing 14 changed files with 436 additions and 254 deletions.
1 change: 1 addition & 0 deletions service/frontend/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ var (
errInvalidWorkflowStartDelaySeconds = serviceerror.NewInvalidArgument("An invalid WorkflowStartDelaySeconds is set on request.")
errRaceConditionAddingSearchAttributes = serviceerror.NewUnavailable("Generated search attributes mapping unavailble.")
errUseVersioningWithoutBuildId = serviceerror.NewInvalidArgument("WorkerVersionStamp must be present if UseVersioning is true.")
errUseVersioningWithoutNormalName = serviceerror.NewInvalidArgument("NormalName must be set on sticky queue if UseVersioning is true.")
errBuildIdTooLong = serviceerror.NewInvalidArgument("Build ID exceeds configured limit.workerBuildIdSize, use a shorter build ID.")

errUpdateMetaNotSet = serviceerror.NewInvalidArgument("Update meta is not set on request.")
Expand Down
61 changes: 31 additions & 30 deletions service/frontend/workflow_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,12 +840,8 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w
return nil, errIdentityTooLong
}

if request.GetWorkerVersionCapabilities().GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(request.Namespace) {
return nil, errWorkerVersioningNotAllowed
}

if len(request.GetWorkerVersionCapabilities().GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() {
return nil, errBuildIdTooLong
if err := wh.validateVersioningInfo(request.Namespace, request.WorkerVersionCapabilities, request.TaskQueue); err != nil {
return nil, err
}

if err := wh.validateTaskQueue(request.TaskQueue); err != nil {
Expand All @@ -858,14 +854,6 @@ func (wh *WorkflowHandler) PollWorkflowTaskQueue(ctx context.Context, request *w
}
namespaceID := namespaceEntry.ID()

// Copy WorkerVersionCapabilities.BuildId to BinaryChecksum if BinaryChecksum is missing (small
// optimization to save space in the poll request).
if request.WorkerVersionCapabilities != nil {
if len(request.WorkerVersionCapabilities.BuildId) > 0 && len(request.BinaryChecksum) == 0 {
request.BinaryChecksum = request.WorkerVersionCapabilities.BuildId
}
}

wh.logger.Debug("Poll workflow task queue.", tag.WorkflowNamespace(namespaceEntry.Name().String()), tag.WorkflowNamespaceID(namespaceID.String()))
if err := wh.checkBadBinary(namespaceEntry, request.GetBinaryChecksum()); err != nil {
return nil, err
Expand Down Expand Up @@ -942,12 +930,12 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(
return nil, errIdentityTooLong
}

if request.GetWorkerVersionStamp().GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(request.Namespace) {
return nil, errWorkerVersioningNotAllowed
}

if len(request.GetWorkerVersionStamp().GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() {
return nil, errBuildIdTooLong
if err := wh.validateVersioningInfo(
request.Namespace,
request.WorkerVersionStamp,
request.StickyAttributes.GetWorkerTaskQueue(),
); err != nil {
return nil, err
}

taskToken, err := wh.tokenSerializer.Deserialize(request.TaskToken)
Expand All @@ -956,10 +944,6 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(
}
namespaceId := namespace.ID(taskToken.GetNamespaceId())

if request.WorkerVersionStamp.GetUseVersioning() && len(request.WorkerVersionStamp.GetBuildId()) == 0 {
return nil, errUseVersioningWithoutBuildId
}

wh.overrides.DisableEagerActivityDispatchForBuggyClients(ctx, request)

histResp, err := wh.historyClient.RespondWorkflowTaskCompleted(ctx, &historyservice.RespondWorkflowTaskCompletedRequest{
Expand Down Expand Up @@ -1102,12 +1086,8 @@ func (wh *WorkflowHandler) PollActivityTaskQueue(ctx context.Context, request *w
return nil, errIdentityTooLong
}

if request.GetWorkerVersionCapabilities().GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(request.Namespace) {
return nil, errWorkerVersioningNotAllowed
}

if len(request.GetWorkerVersionCapabilities().GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() {
return nil, errBuildIdTooLong
if err := wh.validateVersioningInfo(request.Namespace, request.WorkerVersionCapabilities, request.TaskQueue); err != nil {
return nil, err
}

namespaceID, err := wh.namespaceRegistry.GetNamespaceID(namespace.Name(request.GetNamespace()))
Expand Down Expand Up @@ -4292,6 +4272,27 @@ func (wh *WorkflowHandler) validateTaskQueue(t *taskqueuepb.TaskQueue) error {
return nil
}

type buildIdAndFlag interface {
GetBuildId() string
GetUseVersioning() bool
}

func (wh *WorkflowHandler) validateVersioningInfo(namespace string, id buildIdAndFlag, tq *taskqueuepb.TaskQueue) error {
if id.GetUseVersioning() && !wh.config.EnableWorkerVersioningWorkflow(namespace) {
return errWorkerVersioningNotAllowed
}
if id.GetUseVersioning() && tq.GetKind() == enumspb.TASK_QUEUE_KIND_STICKY && len(tq.GetNormalName()) == 0 {
return errUseVersioningWithoutNormalName
}
if id.GetUseVersioning() && len(id.GetBuildId()) == 0 {
return errUseVersioningWithoutBuildId
}
if len(id.GetBuildId()) > wh.config.WorkerBuildIdSizeLimit() {
return errBuildIdTooLong
}
return nil
}

//nolint:revive // cyclomatic complexity
func (wh *WorkflowHandler) validateBuildIdCompatibilityUpdate(
req *workflowservice.UpdateWorkerBuildIdCompatibilityRequest,
Expand Down
5 changes: 3 additions & 2 deletions service/history/api/get_workflow_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,9 @@ func MutableStateToGetResponse(
Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
},
StickyTaskQueue: &taskqueuepb.TaskQueue{
Name: executionInfo.StickyTaskQueue,
Kind: enumspb.TASK_QUEUE_KIND_STICKY,
Name: executionInfo.StickyTaskQueue,
Kind: enumspb.TASK_QUEUE_KIND_STICKY,
NormalName: executionInfo.TaskQueue,
},
StickyTaskQueueScheduleToStartTimeout: executionInfo.StickyScheduleToStartTimeout,
CurrentBranchToken: currentBranchToken,
Expand Down
3 changes: 2 additions & 1 deletion service/history/transferQueueActiveTaskExecutor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2642,8 +2642,9 @@ func (s *transferQueueActiveTaskExecutorSuite) createAddWorkflowTaskRequest(
}
executionInfo := mutableState.GetExecutionInfo()
timeout := executionInfo.WorkflowRunTimeout
if mutableState.GetExecutionInfo().TaskQueue != task.TaskQueue {
if executionInfo.TaskQueue != task.TaskQueue {
taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY
taskQueue.NormalName = executionInfo.TaskQueue
timeout = executionInfo.StickyScheduleToStartTimeout
}

Expand Down
10 changes: 6 additions & 4 deletions service/history/workflow/mutable_state_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,9 @@ func (ms *MutableStateImpl) GetNamespaceEntry() *namespace.Namespace {
func (ms *MutableStateImpl) CurrentTaskQueue() *taskqueuepb.TaskQueue {
if ms.IsStickyTaskQueueSet() {
return &taskqueuepb.TaskQueue{
Name: ms.executionInfo.StickyTaskQueue,
Kind: enumspb.TASK_QUEUE_KIND_STICKY,
Name: ms.executionInfo.StickyTaskQueue,
Kind: enumspb.TASK_QUEUE_KIND_STICKY,
NormalName: ms.executionInfo.TaskQueue,
}
}
return &taskqueuepb.TaskQueue{
Expand Down Expand Up @@ -707,8 +708,9 @@ func (ms *MutableStateImpl) IsStickyTaskQueueSet() bool {
func (ms *MutableStateImpl) TaskQueueScheduleToStartTimeout(name string) (*taskqueuepb.TaskQueue, *time.Duration) {
if ms.executionInfo.TaskQueue != name {
return &taskqueuepb.TaskQueue{
Name: ms.executionInfo.StickyTaskQueue,
Kind: enumspb.TASK_QUEUE_KIND_STICKY,
Name: ms.executionInfo.StickyTaskQueue,
Kind: enumspb.TASK_QUEUE_KIND_STICKY,
NormalName: ms.executionInfo.TaskQueue,
}, ms.executionInfo.StickyScheduleToStartTimeout
}
return &taskqueuepb.TaskQueue{
Expand Down
10 changes: 8 additions & 2 deletions service/matching/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ func (db *taskQueueDB) CompleteTasksLessThan(
return n, err
}

// Returns true if we are storing user data in the db. We need to be the root partition,
// workflow type, unversioned, and also a normal queue.
func (db *taskQueueDB) DbStoresUserData() bool {
return db.taskQueue.OwnsUserData() && db.taskQueueKind == enumspb.TASK_QUEUE_KIND_NORMAL
}

// GetUserData returns the versioning data for this task queue. Do not mutate the returned pointer, as doing so
// will cause cache inconsistency.
func (db *taskQueueDB) GetUserData(
Expand All @@ -320,7 +326,7 @@ func (db *taskQueueDB) getUserDataLocked(
ctx context.Context,
) (*persistencespb.VersionedTaskQueueUserData, chan struct{}, error) {
if db.userData == nil {
if !db.taskQueue.OwnsUserData() {
if !db.DbStoresUserData() {
return nil, db.userDataChanged, nil
}

Expand Down Expand Up @@ -350,7 +356,7 @@ func (db *taskQueueDB) getUserDataLocked(
//
// On success returns a pointer to the updated data, which must *not* be mutated.
func (db *taskQueueDB) UpdateUserData(ctx context.Context, updateFn func(*persistencespb.TaskQueueUserData) (*persistencespb.TaskQueueUserData, error), taskQueueLimitPerBuildId int) (*persistencespb.VersionedTaskQueueUserData, error) {
if !db.taskQueue.OwnsUserData() {
if !db.DbStoresUserData() {
return nil, errUserDataNoMutateNonRoot
}
db.Lock()
Expand Down
Loading

0 comments on commit 2de41e5

Please sign in to comment.