Skip to content

Commit

Permalink
Address David's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
bergundy committed May 26, 2023
1 parent 00821c1 commit a2d884e
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 105 deletions.
2 changes: 1 addition & 1 deletion cmd/tools/rpcwrappers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func makeGetMatchingClient(reqType reflect.Type) string {
switch t.Name() {
case "GetBuildIdTaskQueueMappingRequest":
// Pick a random node for this request, it's not associated with a specific task queue.
tqPath = "&taskqueuepb.TaskQueue{Name: fmt.Sprintf(\"not-applicable-%%s\", rand.Int())}"
tqPath = "&taskqueuepb.TaskQueue{Name: fmt.Sprintf(\"not-applicable-%s\", rand.Int())}"
tqtPath = "enumspb.TASK_QUEUE_TYPE_UNSPECIFIED"
return fmt.Sprintf("client, err := c.getClientForTaskqueue(%s, %s, %s)", nsIDPath, tqPath, tqtPath)
case "UpdateTaskQueueUserDataRequest",
Expand Down
13 changes: 8 additions & 5 deletions common/persistence/sql/sqlplugin/matching_task_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ type (
DataEncoding string
}

AddBuildIdsToTaskQueueMapping struct {
AddToBuildIdToTaskQueueMapping struct {
NamespaceID []byte
TaskQueueName string
BuildIds []string
}

RemoveBuildIdsToTaskQueueMapping struct {
RemoveFromBuildIdToTaskQueueMapping struct {
NamespaceID []byte
TaskQueueName string
BuildIds []string
Expand All @@ -81,7 +81,10 @@ type (
BuildID string
}

CountTaskQueuesByBuildIdRequest = GetTaskQueuesByBuildIdRequest
// CountTaskQueuesByBuildIdRequest is the request type accepted by CountTaskQueuesByBuildId.
// It carries the namespace ID and the build ID to count task queues for.
CountTaskQueuesByBuildIdRequest struct {
NamespaceID []byte
BuildID string
}

VersionedBlob struct {
Version int64
Expand Down Expand Up @@ -114,8 +117,8 @@ type (
LockTaskQueues(ctx context.Context, filter TaskQueuesFilter) (int64, error)
GetTaskQueueUserData(ctx context.Context, request *GetTaskQueueUserDataRequest) (*VersionedBlob, error)
UpdateTaskQueueUserData(ctx context.Context, request *UpdateTaskQueueDataRequest) error
AddBuildIdToTaskQueueMapping(ctx context.Context, request AddBuildIdsToTaskQueueMapping) error
RemoveBuildIdToTaskQueueMapping(ctx context.Context, request RemoveBuildIdsToTaskQueueMapping) error
AddBuildIdToTaskQueueMapping(ctx context.Context, request AddToBuildIdToTaskQueueMapping) error
RemoveBuildIdToTaskQueueMapping(ctx context.Context, request RemoveFromBuildIdToTaskQueueMapping) error
ListTaskQueueUserDataEntries(ctx context.Context, request *ListTaskQueueUserDataEntriesRequest) ([]TaskQueueUserDataEntry, error)
GetTaskQueuesByBuildId(ctx context.Context, request *GetTaskQueuesByBuildIdRequest) ([]string, error)
CountTaskQueuesByBuildId(ctx context.Context, request *CountTaskQueuesByBuildIdRequest) (int, error)
Expand Down
4 changes: 2 additions & 2 deletions common/persistence/sql/sqlplugin/mysql/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func (mdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.U
return nil
}

func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddBuildIdsToTaskQueueMapping) error {
func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error {
query := addBuildIdToTaskQueueMappingQry
var params []any
for idx, buildId := range request.BuildIds {
Expand All @@ -356,7 +356,7 @@ func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplug
return err
}

func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveBuildIdsToTaskQueueMapping) error {
func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error {
// TODO(bergundy): implement when we support deletion
panic("not implemented")
}
Expand Down
4 changes: 2 additions & 2 deletions common/persistence/sql/sqlplugin/postgresql/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func (pdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.U
return nil
}

func (pdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddBuildIdsToTaskQueueMapping) error {
func (pdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error {
query := addBuildIdToTaskQueueMappingQry
var params []any
for idx, buildId := range request.BuildIds {
Expand All @@ -355,7 +355,7 @@ func (pdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplug
return err
}

func (pdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveBuildIdsToTaskQueueMapping) error {
func (pdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error {
// TODO(bergundy): implement when we support deletion
panic("not implemented")
}
Expand Down
22 changes: 11 additions & 11 deletions common/persistence/sql/sqlplugin/sqlite/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"

"go.temporal.io/api/serviceerror"

Expand Down Expand Up @@ -110,7 +111,7 @@ task_queue_id = :task_queue_id
listTaskQueueUserDataQry = `SELECT task_queue_name, data, data_encoding FROM task_queue_user_data WHERE namespace_id = ? AND task_queue_name > ? LIMIT ?`

addBuildIdToTaskQueueMappingQry = `INSERT INTO build_id_to_task_queue (namespace_id, build_id, task_queue_name) VALUES `
removeBuildIdToTaskQueueMappingQry = `DELETE FROM build_id_to_task_queue WHERE `
removeBuildIdToTaskQueueMappingQry = `DELETE FROM build_id_to_task_queue WHERE namespace_id = ? AND task_queue_name = ? AND build_id IN (`
listTaskQueuesByBuildIdQry = `SELECT task_queue_name FROM build_id_to_task_queue WHERE namespace_id = ? AND build_id = ?`
countTaskQueuesByBuildIdQry = `SELECT COUNT(*) FROM build_id_to_task_queue WHERE namespace_id = ? AND build_id = ?`
)
Expand Down Expand Up @@ -346,7 +347,7 @@ func (mdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.U
return nil
}

func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddBuildIdsToTaskQueueMapping) error {
func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error {
query := addBuildIdToTaskQueueMappingQry
var params []any
for idx, buildId := range request.BuildIds {
Expand All @@ -362,15 +363,14 @@ func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplug
return err
}

func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveBuildIdsToTaskQueueMapping) error {
query := removeBuildIdToTaskQueueMappingQry
var params []any
for idx, buildId := range request.BuildIds {
query += "namespace_id = ? AND build_id = ? AND task_queue_name = ?"
if idx < len(request.BuildIds)-1 {
query += " OR "
}
params = append(params, request.NamespaceID, buildId, request.TaskQueueName)
func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error {
query := removeBuildIdToTaskQueueMappingQry + strings.Repeat("?, ", len(request.BuildIds)-1) + "?)"
// Golang doesn't support appending a string slice to an any slice which is essentially what we're doing here.
params := make([]any, len(request.BuildIds)+2)
params[0] = request.NamespaceID
params[1] = request.TaskQueueName
for i, buildId := range request.BuildIds {
params[i+2] = buildId
}

_, err := mdb.conn.ExecContext(ctx, query, params...)
Expand Down
4 changes: 2 additions & 2 deletions common/persistence/sql/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func (m *sqlTaskManager) UpdateTaskQueueUserData(ctx context.Context, request *p
return err
}
if len(request.BuildIdsAdded) > 0 {
err = tx.AddBuildIdToTaskQueueMapping(ctx, sqlplugin.AddBuildIdsToTaskQueueMapping{
err = tx.AddBuildIdToTaskQueueMapping(ctx, sqlplugin.AddToBuildIdToTaskQueueMapping{
NamespaceID: namespaceID,
TaskQueueName: request.TaskQueue,
BuildIds: request.BuildIdsAdded,
Expand All @@ -528,7 +528,7 @@ func (m *sqlTaskManager) UpdateTaskQueueUserData(ctx context.Context, request *p
}
}
if len(request.BuildIdsRemoved) > 0 {
err = tx.RemoveBuildIdToTaskQueueMapping(ctx, sqlplugin.RemoveBuildIdsToTaskQueueMapping{
err = tx.RemoveBuildIdToTaskQueueMapping(ctx, sqlplugin.RemoveFromBuildIdToTaskQueueMapping{
NamespaceID: namespaceID,
TaskQueueName: request.TaskQueue,
BuildIds: request.BuildIdsRemoved,
Expand Down
35 changes: 28 additions & 7 deletions common/worker_versioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@

package common

import commonpb "go.temporal.io/api/common/v1"
import (
commonpb "go.temporal.io/api/common/v1"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
)

const buildIdSearchAttributePrefixVersioned = "versioned"
const buildIdSearchAttributePrefixUnversioned = "unversioned"
const BuildIdSearchAttributeDelimiter = ":"

// UnversionedSearchAttribute is the sentinel value used to mark all unversioned workflows
const UnversionedSearchAttribute = buildIdSearchAttributePrefixUnversioned
// Constants used when building build id search attribute values.
const (
// buildIdSearchAttributePrefixVersioned is presumably the prefix for values of versioned build ids — confirm against its users.
buildIdSearchAttributePrefixVersioned = "versioned"
// buildIdSearchAttributePrefixUnversioned is presumably the prefix for values of unversioned build ids — confirm against its users.
buildIdSearchAttributePrefixUnversioned = "unversioned"
// BuildIdSearchAttributeDelimiter is presumably the separator between a prefix and the build id — confirm against its users.
BuildIdSearchAttributeDelimiter = ":"
// UnversionedSearchAttribute is the sentinel value used to mark all unversioned workflows
UnversionedSearchAttribute = buildIdSearchAttributePrefixUnversioned
)

// VersionedBuildIdSearchAttribute returns the search attribute value for a versioned build id
func VersionedBuildIdSearchAttribute(buildId string) string {
Expand All @@ -53,3 +57,20 @@ func VersionStampToBuildIdSearchAttribute(stamp *commonpb.WorkerVersionStamp) st
}
return UnversionedBuildIdSearchAttribute(stamp.BuildId)
}

// FindBuildId searches versionSets for buildId.
//
// It returns the index of the set that contains the build id and the index of
// the build id inside that set. Both results are -1 when the build id is not
// present in any set (including when versionSets is nil or empty).
//
// The search stops at the first match, so if the same build id appears in
// more than one set the earliest set wins.
func FindBuildId(versionSets []*taskqueuepb.CompatibleVersionSet, buildId string) (setIndex, indexInSet int) {
	for sidx, set := range versionSets {
		// GetBuildIds is nil-safe, so a nil set entry is skipped instead of panicking.
		for bidx, id := range set.GetBuildIds() {
			if buildId == id {
				// Return immediately: a plain break would only leave the inner
				// loop and keep scanning the remaining sets.
				return sidx, bidx
			}
		}
	}
	return -1, -1
}
1 change: 0 additions & 1 deletion schema/cassandra/temporal/schema.cql
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ CREATE TABLE task_queue_user_data (
build_id_if_row_is_an_index text, -- If this row is used as a mapping of build id to task queue, this will not be empty
data blob, -- temporal.server.api.persistence.v1.TaskQueueUserData
data_encoding text, -- Encoding type used for serialization, in practice this should always be proto3
build_ids set<text>, -- All active build ids in all version sets on this task queue (used in an index below)
version bigint, -- Version of this row, used for optimistic concurrency
-- task_queue_name is not a part of the partitioning key to allow cheaply iterating all task queues in a single
-- namespace. Access to this table should be infrequent enough that a single partition per namespace can be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ CREATE TABLE task_queue_user_data (
build_id_if_row_is_an_index text, -- If this row is used as a mapping of build id to task queue, this will not be empty
data blob, -- temporal.server.api.persistence.v1.TaskQueueUserData
data_encoding text, -- Encoding type used for serialization, in practice this should always be proto3
build_ids set<text>, -- All active build ids in all version sets on this task queue (used in an index below)
version bigint, -- Version of this row, used for optimistic concurrency
-- task_queue_name is not a part of the partitioning key to allow cheaply iterating all task queues in a single
-- namespace. Access to this table should be infrequent enough that a single partition per namespace can be used.
Expand Down
92 changes: 29 additions & 63 deletions service/frontend/task_reachability.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,61 +30,41 @@ import (
"strings"
"sync"

"github.com/xwb1989/sqlparser"
enumspb "go.temporal.io/api/enums/v1"
"go.temporal.io/api/serviceerror"
taskqueuepb "go.temporal.io/api/taskqueue/v1"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/server/api/matchingservice/v1"
"go.temporal.io/server/common"
"go.temporal.io/server/common/future"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/persistence/visibility/manager"
"go.temporal.io/server/common/searchattribute"
"go.temporal.io/server/common/util"
)

// Helper for deduping GetWorkerBuildIdCompatibility matching requests.
type versionSetFetcher struct {
sync.Mutex
matchingClient matchingservice.MatchingServiceClient
notificationChs map[string]chan struct{}
responses map[string]versionSetFetcherResponse
}

type versionSetFetcherResponse struct {
value *matchingservice.GetWorkerBuildIdCompatibilityResponse
err error
lock sync.Mutex
matchingClient matchingservice.MatchingServiceClient
futures map[string]future.Future[*matchingservice.GetWorkerBuildIdCompatibilityResponse]
}

func newVersionSetFetcher(matchingClient matchingservice.MatchingServiceClient) *versionSetFetcher {
return &versionSetFetcher{
matchingClient: matchingClient,
notificationChs: make(map[string]chan struct{}),
responses: make(map[string]versionSetFetcherResponse),
matchingClient: matchingClient,
futures: make(map[string]future.Future[*matchingservice.GetWorkerBuildIdCompatibilityResponse]),
}
}

func (f *versionSetFetcher) getStoredResponse(taskQueue string) (*matchingservice.GetWorkerBuildIdCompatibilityResponse, error) {
f.Lock()
defer f.Unlock()
if response, found := f.responses[taskQueue]; found {
return response.value, response.err
}
return nil, nil
}

func (f *versionSetFetcher) fetchTaskQueueVersions(ctx context.Context, ns *namespace.Namespace, taskQueue string) (*matchingservice.GetWorkerBuildIdCompatibilityResponse, error) {
response, err := f.getStoredResponse(taskQueue)
if err != nil || response != nil {
return response, err
}

f.Lock()
var waitCh chan struct{}
if ch, found := f.notificationChs[taskQueue]; found {
waitCh = ch
} else {
waitCh = make(chan struct{})
f.notificationChs[taskQueue] = waitCh

func (f *versionSetFetcher) getFuture(ctx context.Context, ns *namespace.Namespace, taskQueue string) future.Future[*matchingservice.GetWorkerBuildIdCompatibilityResponse] {
f.lock.Lock()
defer f.lock.Unlock()
_, found := f.futures[taskQueue]
if !found {
fut := future.NewFuture[*matchingservice.GetWorkerBuildIdCompatibilityResponse]()
f.futures[taskQueue] = fut
go func() {
value, err := f.matchingClient.GetWorkerBuildIdCompatibility(ctx, &matchingservice.GetWorkerBuildIdCompatibilityRequest{
NamespaceId: ns.ID().String(),
Expand All @@ -93,16 +73,15 @@ func (f *versionSetFetcher) fetchTaskQueueVersions(ctx context.Context, ns *name
TaskQueue: taskQueue,
},
})
f.Lock()
defer f.Unlock()
f.responses[taskQueue] = versionSetFetcherResponse{value: value, err: err}
close(waitCh)
fut.Set(value, err)
}()
}
f.Unlock()
return f.futures[taskQueue]
}

<-waitCh
return f.getStoredResponse(taskQueue)
// fetchTaskQueueVersions returns the GetWorkerBuildIdCompatibility response for
// taskQueue. It obtains the future for that task queue from getFuture and
// returns the result of calling Get on it with ctx.
func (f *versionSetFetcher) fetchTaskQueueVersions(ctx context.Context, ns *namespace.Namespace, taskQueue string) (*matchingservice.GetWorkerBuildIdCompatibilityResponse, error) {
	return f.getFuture(ctx, ns, taskQueue).Get(ctx)
}

// Implementation of the GetWorkerTaskReachability API. Expects an already validated request.
Expand Down Expand Up @@ -202,23 +181,10 @@ func (wh *WorkflowHandler) getTaskQueueReachability(ctx context.Context, request
// Query for the unversioned worker
isDefaultInQueue = len(request.versionSets) == 0
// Query workflows that have completed tasks marked with a sentinel "unversioned" search attribute.
buildIdsFilter = fmt.Sprintf("BuildIds = %q", common.UnversionedSearchAttribute)
buildIdsFilter = fmt.Sprintf(`%s = "%s"`, searchattribute.BuildIds, common.UnversionedSearchAttribute)
} else {
// Query for a versioned worker
setIdx := -1
buildIdIdx := -1
if len(request.versionSets) > 0 {
for sidx, set := range request.versionSets {
for bidx, id := range set.BuildIds {
if request.buildId == id {
setIdx = sidx
buildIdIdx = bidx
break
}
}
}
}

setIdx, buildIdIdx := common.FindBuildId(request.versionSets, request.buildId)
if setIdx == -1 {
// build id not in set - unreachable
return &taskQueueReachability, nil
Expand All @@ -234,9 +200,9 @@ func (wh *WorkflowHandler) getTaskQueueReachability(ctx context.Context, request
isDefaultInQueue = setIdx == len(request.versionSets)-1
escapedBuildIds := make([]string, len(set.GetBuildIds()))
for i, buildId := range set.GetBuildIds() {
escapedBuildIds[i] = fmt.Sprintf("%q", common.VersionedBuildIdSearchAttribute(buildId))
escapedBuildIds[i] = sqlparser.String(sqlparser.NewStrVal([]byte(common.VersionedBuildIdSearchAttribute(buildId))))
}
buildIdsFilter = fmt.Sprintf("BuildIds IN (%s)", strings.Join(escapedBuildIds, ","))
buildIdsFilter = fmt.Sprintf("%s IN (%s)", searchattribute.BuildIds, strings.Join(escapedBuildIds, ","))
}

if isDefaultInQueue {
Expand All @@ -246,7 +212,7 @@ func (wh *WorkflowHandler) getTaskQueueReachability(ctx context.Context, request
)
// Take into account started workflows that have not yet been processed by any worker.
if request.reachabilityType != enumspb.TASK_REACHABILITY_CLOSED_WORKFLOWS {
buildIdsFilter = fmt.Sprintf("(BuildIds IS NULL OR %s)", buildIdsFilter)
buildIdsFilter = fmt.Sprintf("(%s IS NULL OR %s)", searchattribute.BuildIds, buildIdsFilter)
}
}

Expand All @@ -268,9 +234,9 @@ func (wh *WorkflowHandler) queryVisibilityForExisitingWorkflowsReachability(
statusFilter := ""
switch reachabilityType {
case enumspb.TASK_REACHABILITY_OPEN_WORKFLOWS:
statusFilter = " AND ExecutionStatus = \"Running\""
statusFilter = fmt.Sprintf(` AND %s = "Running"`, searchattribute.ExecutionStatus)
case enumspb.TASK_REACHABILITY_CLOSED_WORKFLOWS:
statusFilter = " AND ExecutionStatus != \"Running\""
statusFilter = fmt.Sprintf(` AND %s != "Running"`, searchattribute.ExecutionStatus)
case enumspb.TASK_REACHABILITY_UNSPECIFIED:
reachabilityType = enumspb.TASK_REACHABILITY_EXISTING_WORKFLOWS
statusFilter = ""
Expand All @@ -285,7 +251,7 @@ func (wh *WorkflowHandler) queryVisibilityForExisitingWorkflowsReachability(
req := manager.CountWorkflowExecutionsRequest{
NamespaceID: ns.ID(),
Namespace: ns.Name(),
Query: fmt.Sprintf("TaskQueue = %q AND %s%s", taskQueue, buildIdsFilter, statusFilter),
Query: fmt.Sprintf("%s = %q AND %s%s", searchattribute.TaskQueue, taskQueue, buildIdsFilter, statusFilter),
}

// TODO(bergundy): is count more efficient than select with page size of 1?
Expand Down
2 changes: 1 addition & 1 deletion service/frontend/workflow_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3653,7 +3653,7 @@ func (wh *WorkflowHandler) GetWorkerTaskReachability(ctx context.Context, reques
}
}
if gotUnversionedRequest && len(request.GetTaskQueues()) == 0 {
return nil, serviceerror.NewInvalidArgument("Cannot get reachability of an unversioned worker without specifying at least one task queue")
return nil, serviceerror.NewInvalidArgument("Cannot get reachability of an unversioned worker without specifying at least one task queue (empty build id is interpereted as unversioned)")
}
for _, taskQueue := range request.GetTaskQueues() {
taskQueue := &taskqueuepb.TaskQueue{Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}
Expand Down
1 change: 0 additions & 1 deletion service/matching/taskQueueManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,6 @@ func (c *taskQueueManagerImpl) GetUserData(ctx context.Context) (*persistencespb
return c.db.GetUserData(ctx)
}

//nolint:revive // control coupling
func (c *taskQueueManagerImpl) UpdateUserData(ctx context.Context, options UserDataUpdateOptions, updateFn UserDataUpdateFunc) error {
newData, err := c.db.UpdateUserData(ctx, updateFn, options.TaskQueueLimitPerBuildId)
if err != nil {
Expand Down
Loading

0 comments on commit a2d884e

Please sign in to comment.