diff --git a/cmd/tools/rpcwrappers/main.go b/cmd/tools/rpcwrappers/main.go index 143309cf954..e3cfd64c05b 100644 --- a/cmd/tools/rpcwrappers/main.go +++ b/cmd/tools/rpcwrappers/main.go @@ -186,7 +186,7 @@ func makeGetMatchingClient(reqType reflect.Type) string { switch t.Name() { case "GetBuildIdTaskQueueMappingRequest": // Pick a random node for this request, it's not associated with a specific task queue. - tqPath = "&taskqueuepb.TaskQueue{Name: fmt.Sprintf(\"not-applicable-%%s\", rand.Int())}" + tqPath = "&taskqueuepb.TaskQueue{Name: fmt.Sprintf(\"not-applicable-%d\", rand.Int())}" tqtPath = "enumspb.TASK_QUEUE_TYPE_UNSPECIFIED" return fmt.Sprintf("client, err := c.getClientForTaskqueue(%s, %s, %s)", nsIDPath, tqPath, tqtPath) case "UpdateTaskQueueUserDataRequest", diff --git a/common/persistence/cassandra/matching_task_store.go b/common/persistence/cassandra/matching_task_store.go index 4c7fa1752d2..c2374f30f3c 100644 --- a/common/persistence/cassandra/matching_task_store.go +++ b/common/persistence/cassandra/matching_task_store.go @@ -134,7 +134,7 @@ const ( templateGetTaskQueueUserDataQuery = `SELECT data, data_encoding, version FROM task_queue_user_data - WHERE namespace_id = ? AND build_id_if_row_is_an_index = '' + WHERE namespace_id = ? AND build_id = '' AND task_queue_name = ?` templateUpdateTaskQueueUserDataQuery = `UPDATE task_queue_user_data SET @@ -142,22 +142,22 @@ const ( data_encoding = ?, version = ? WHERE namespace_id = ? - AND build_id_if_row_is_an_index = '' + AND build_id = '' AND task_queue_name = ? IF version = ?` templateInsertTaskQueueUserDataQuery = `INSERT INTO task_queue_user_data - (namespace_id, build_id_if_row_is_an_index, task_queue_name, data, data_encoding, version) VALUES - (? , '' , ? , ? , ? , 1 ) IF NOT EXISTS` + (namespace_id, build_id, task_queue_name, data, data_encoding, version) VALUES + (? , '' , ? , ? , ? 
, 1 ) IF NOT EXISTS` templateInsertBuildIdTaskQueueMappingQuery = `INSERT INTO task_queue_user_data - (namespace_id, build_id_if_row_is_an_index, task_queue_name) VALUES - (? , ? , ?)` + (namespace_id, build_id, task_queue_name) VALUES + (? , ? , ?)` templateDeleteBuildIdTaskQueueMappingQuery = `DELETE FROM task_queue_user_data - WHERE namespace_id = ? AND build_id_if_row_is_an_index = ? AND task_queue_name = ?` - templateListTaskQueueUserDataQuery = `SELECT task_queue_name, data, data_encoding FROM task_queue_user_data WHERE namespace_id = ? AND build_id_if_row_is_an_index = ''` - templateListTaskQueueNamesByBuildIdQuery = `SELECT task_queue_name FROM task_queue_user_data WHERE namespace_id = ? AND build_id_if_row_is_an_index = ?` - templateCountTaskQueueByBuildIdQuery = `SELECT COUNT(*) FROM task_queue_user_data WHERE namespace_id = ? AND build_id_if_row_is_an_index = ?` + WHERE namespace_id = ? AND build_id = ? AND task_queue_name = ?` + templateListTaskQueueUserDataQuery = `SELECT task_queue_name, data, data_encoding FROM task_queue_user_data WHERE namespace_id = ? AND build_id = ''` + templateListTaskQueueNamesByBuildIdQuery = `SELECT task_queue_name FROM task_queue_user_data WHERE namespace_id = ? AND build_id = ?` + templateCountTaskQueueByBuildIdQuery = `SELECT COUNT(*) FROM task_queue_user_data WHERE namespace_id = ? 
AND build_id = ?` // Not much of a need to make this configurable, we're just reading some strings listTaskQueueNamesByBuildIdPageSize = 100 diff --git a/common/persistence/dataInterfaces.go b/common/persistence/dataInterfaces.go index 681e97f8498..f77876dd5b7 100644 --- a/common/persistence/dataInterfaces.go +++ b/common/persistence/dataInterfaces.go @@ -585,7 +585,10 @@ type ( BuildID string } - CountTaskQueuesByBuildIdRequest = GetTaskQueuesByBuildIdRequest + CountTaskQueuesByBuildIdRequest struct { + NamespaceID string + BuildID string + } // ListTaskQueueRequest contains the request params needed to invoke ListTaskQueue API ListTaskQueueRequest struct { diff --git a/common/persistence/persistenceInterface.go b/common/persistence/persistenceInterface.go index e0505d8bd8d..57d9d9b62bc 100644 --- a/common/persistence/persistenceInterface.go +++ b/common/persistence/persistenceInterface.go @@ -76,7 +76,7 @@ type ( UpdateTaskQueueUserData(ctx context.Context, request *InternalUpdateTaskQueueUserDataRequest) error ListTaskQueueUserDataEntries(ctx context.Context, request *ListTaskQueueUserDataEntriesRequest) (*InternalListTaskQueueUserDataEntriesResponse, error) GetTaskQueuesByBuildId(ctx context.Context, request *GetTaskQueuesByBuildIdRequest) ([]string, error) - CountTaskQueuesByBuildId(ctx context.Context, request *GetTaskQueuesByBuildIdRequest) (int, error) + CountTaskQueuesByBuildId(ctx context.Context, request *CountTaskQueuesByBuildIdRequest) (int, error) } // MetadataStore is a lower level of MetadataManager MetadataStore interface { diff --git a/common/persistence/sql/sqlplugin/matching_task_queue.go b/common/persistence/sql/sqlplugin/matching_task_queue.go index f478a2a7d3d..267a78b6328 100644 --- a/common/persistence/sql/sqlplugin/matching_task_queue.go +++ b/common/persistence/sql/sqlplugin/matching_task_queue.go @@ -64,13 +64,13 @@ type ( DataEncoding string } - AddBuildIdsToTaskQueueMapping struct { + AddToBuildIdToTaskQueueMapping struct { NamespaceID 
[]byte TaskQueueName string BuildIds []string } - RemoveBuildIdsToTaskQueueMapping struct { + RemoveFromBuildIdToTaskQueueMapping struct { NamespaceID []byte TaskQueueName string BuildIds []string @@ -81,7 +81,10 @@ type ( BuildID string } - CountTaskQueuesByBuildIdRequest = GetTaskQueuesByBuildIdRequest + CountTaskQueuesByBuildIdRequest struct { + NamespaceID []byte + BuildID string + } VersionedBlob struct { Version int64 @@ -114,8 +117,8 @@ type ( LockTaskQueues(ctx context.Context, filter TaskQueuesFilter) (int64, error) GetTaskQueueUserData(ctx context.Context, request *GetTaskQueueUserDataRequest) (*VersionedBlob, error) UpdateTaskQueueUserData(ctx context.Context, request *UpdateTaskQueueDataRequest) error - AddBuildIdToTaskQueueMapping(ctx context.Context, request AddBuildIdsToTaskQueueMapping) error - RemoveBuildIdToTaskQueueMapping(ctx context.Context, request RemoveBuildIdsToTaskQueueMapping) error + AddBuildIdToTaskQueueMapping(ctx context.Context, request AddToBuildIdToTaskQueueMapping) error + RemoveBuildIdToTaskQueueMapping(ctx context.Context, request RemoveFromBuildIdToTaskQueueMapping) error ListTaskQueueUserDataEntries(ctx context.Context, request *ListTaskQueueUserDataEntriesRequest) ([]TaskQueueUserDataEntry, error) GetTaskQueuesByBuildId(ctx context.Context, request *GetTaskQueuesByBuildIdRequest) ([]string, error) CountTaskQueuesByBuildId(ctx context.Context, request *CountTaskQueuesByBuildIdRequest) (int, error) diff --git a/common/persistence/sql/sqlplugin/mysql/task.go b/common/persistence/sql/sqlplugin/mysql/task.go index 7023c39593c..0cd3dbcc686 100644 --- a/common/persistence/sql/sqlplugin/mysql/task.go +++ b/common/persistence/sql/sqlplugin/mysql/task.go @@ -340,7 +340,7 @@ func (mdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.U return nil } -func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddBuildIdsToTaskQueueMapping) error { +func (mdb *db) AddBuildIdToTaskQueueMapping(ctx 
context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error { query := addBuildIdToTaskQueueMappingQry var params []any for idx, buildId := range request.BuildIds { @@ -356,7 +356,7 @@ func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplug return err } -func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveBuildIdsToTaskQueueMapping) error { +func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error { // TODO(bergundy): implement when we support deletion panic("not implemented") } @@ -380,7 +380,7 @@ func (mdb *db) GetTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.Ge return taskQueues, err } -func (mdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.GetTaskQueuesByBuildIdRequest) (int, error) { +func (mdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.CountTaskQueuesByBuildIdRequest) (int, error) { var count int err := mdb.conn.GetContext(ctx, &count, countTaskQueuesByBuildIdQry, request.NamespaceID, request.BuildID) return count, err diff --git a/common/persistence/sql/sqlplugin/postgresql/task.go b/common/persistence/sql/sqlplugin/postgresql/task.go index 2a29d4d8232..77eb9505688 100644 --- a/common/persistence/sql/sqlplugin/postgresql/task.go +++ b/common/persistence/sql/sqlplugin/postgresql/task.go @@ -340,7 +340,7 @@ func (pdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.U return nil } -func (pdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddBuildIdsToTaskQueueMapping) error { +func (pdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error { query := addBuildIdToTaskQueueMappingQry var params []any for idx, buildId := range request.BuildIds { @@ -355,7 +355,7 @@ func (pdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplug return 
err } -func (pdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveBuildIdsToTaskQueueMapping) error { +func (pdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error { // TODO(bergundy): implement when we support deletion panic("not implemented") } @@ -379,7 +379,7 @@ func (pdb *db) GetTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.Ge return taskQueues, err } -func (pdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.GetTaskQueuesByBuildIdRequest) (int, error) { +func (pdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.CountTaskQueuesByBuildIdRequest) (int, error) { var count int err := pdb.conn.GetContext(ctx, &count, countTaskQueuesByBuildIdQry, request.NamespaceID, request.BuildID) return count, err diff --git a/common/persistence/sql/sqlplugin/sqlite/task.go b/common/persistence/sql/sqlplugin/sqlite/task.go index 57b3d7a84b0..93277be5d4b 100644 --- a/common/persistence/sql/sqlplugin/sqlite/task.go +++ b/common/persistence/sql/sqlplugin/sqlite/task.go @@ -30,6 +30,7 @@ import ( "context" "database/sql" "fmt" + "strings" "go.temporal.io/api/serviceerror" @@ -110,7 +111,7 @@ task_queue_id = :task_queue_id listTaskQueueUserDataQry = `SELECT task_queue_name, data, data_encoding FROM task_queue_user_data WHERE namespace_id = ? AND task_queue_name > ? LIMIT ?` addBuildIdToTaskQueueMappingQry = `INSERT INTO build_id_to_task_queue (namespace_id, build_id, task_queue_name) VALUES ` - removeBuildIdToTaskQueueMappingQry = `DELETE FROM build_id_to_task_queue WHERE ` + removeBuildIdToTaskQueueMappingQry = `DELETE FROM build_id_to_task_queue WHERE namespace_id = ? AND task_queue_name = ? AND build_id IN (` listTaskQueuesByBuildIdQry = `SELECT task_queue_name FROM build_id_to_task_queue WHERE namespace_id = ? AND build_id = ?` countTaskQueuesByBuildIdQry = `SELECT COUNT(*) FROM build_id_to_task_queue WHERE namespace_id = ? 
AND build_id = ?` ) @@ -346,7 +347,7 @@ func (mdb *db) UpdateTaskQueueUserData(ctx context.Context, request *sqlplugin.U return nil } -func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddBuildIdsToTaskQueueMapping) error { +func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.AddToBuildIdToTaskQueueMapping) error { query := addBuildIdToTaskQueueMappingQry var params []any for idx, buildId := range request.BuildIds { @@ -362,15 +363,14 @@ func (mdb *db) AddBuildIdToTaskQueueMapping(ctx context.Context, request sqlplug return err } -func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveBuildIdsToTaskQueueMapping) error { - query := removeBuildIdToTaskQueueMappingQry - var params []any - for idx, buildId := range request.BuildIds { - query += "namespace_id = ? AND build_id = ? AND task_queue_name = ?" - if idx < len(request.BuildIds)-1 { - query += " OR " - } - params = append(params, request.NamespaceID, buildId, request.TaskQueueName) +func (mdb *db) RemoveBuildIdToTaskQueueMapping(ctx context.Context, request sqlplugin.RemoveFromBuildIdToTaskQueueMapping) error { + query := removeBuildIdToTaskQueueMappingQry + strings.Repeat("?, ", len(request.BuildIds)-1) + "?)" + // Golang doesn't support appending a string slice to an any slice which is essentially what we're doing here. + params := make([]any, len(request.BuildIds)+2) + params[0] = request.NamespaceID + params[1] = request.TaskQueueName + for i, buildId := range request.BuildIds { + params[i+2] = buildId } _, err := mdb.conn.ExecContext(ctx, query, params...) 
@@ -400,7 +400,7 @@ func (mdb *db) GetTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.Ge return taskQueues, err } -func (mdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.GetTaskQueuesByBuildIdRequest) (int, error) { +func (mdb *db) CountTaskQueuesByBuildId(ctx context.Context, request *sqlplugin.CountTaskQueuesByBuildIdRequest) (int, error) { var count int err := mdb.conn.GetContext(ctx, &count, countTaskQueuesByBuildIdQry, request.NamespaceID, request.BuildID) return count, err diff --git a/common/persistence/sql/task.go b/common/persistence/sql/task.go index 1acf7d72880..97995e9da86 100644 --- a/common/persistence/sql/task.go +++ b/common/persistence/sql/task.go @@ -518,7 +518,7 @@ func (m *sqlTaskManager) UpdateTaskQueueUserData(ctx context.Context, request *p return err } if len(request.BuildIdsAdded) > 0 { - err = tx.AddBuildIdToTaskQueueMapping(ctx, sqlplugin.AddBuildIdsToTaskQueueMapping{ + err = tx.AddBuildIdToTaskQueueMapping(ctx, sqlplugin.AddToBuildIdToTaskQueueMapping{ NamespaceID: namespaceID, TaskQueueName: request.TaskQueue, BuildIds: request.BuildIdsAdded, @@ -528,7 +528,7 @@ func (m *sqlTaskManager) UpdateTaskQueueUserData(ctx context.Context, request *p } } if len(request.BuildIdsRemoved) > 0 { - err = tx.RemoveBuildIdToTaskQueueMapping(ctx, sqlplugin.RemoveBuildIdsToTaskQueueMapping{ + err = tx.RemoveBuildIdToTaskQueueMapping(ctx, sqlplugin.RemoveFromBuildIdToTaskQueueMapping{ NamespaceID: namespaceID, TaskQueueName: request.TaskQueue, BuildIds: request.BuildIdsRemoved, @@ -599,7 +599,7 @@ func (m *sqlTaskManager) CountTaskQueuesByBuildId(ctx context.Context, request * if err != nil { return 0, serviceerror.NewInternal(err.Error()) } - return m.Db.CountTaskQueuesByBuildId(ctx, &sqlplugin.GetTaskQueuesByBuildIdRequest{NamespaceID: namespaceID, BuildID: request.BuildID}) + return m.Db.CountTaskQueuesByBuildId(ctx, &sqlplugin.CountTaskQueuesByBuildIdRequest{NamespaceID: namespaceID, BuildID: request.BuildID}) } 
// Returns uint32 hash for a particular TaskQueue/Task given a Namespace, TaskQueueName and TaskQueueType diff --git a/common/worker_versioning.go b/common/worker_versioning.go index aedb12dc1d5..cc8e5ea6202 100644 --- a/common/worker_versioning.go +++ b/common/worker_versioning.go @@ -24,14 +24,18 @@ package common -import commonpb "go.temporal.io/api/common/v1" +import ( + commonpb "go.temporal.io/api/common/v1" + taskqueuepb "go.temporal.io/api/taskqueue/v1" +) -const buildIdSearchAttributePrefixVersioned = "versioned" -const buildIdSearchAttributePrefixUnversioned = "unversioned" -const BuildIdSearchAttributeDelimiter = ":" - -// UnversionedSearchAttribute is the sentinel value used to mark all unversioned workflows -const UnversionedSearchAttribute = buildIdSearchAttributePrefixUnversioned +const ( + buildIdSearchAttributePrefixVersioned = "versioned" + buildIdSearchAttributePrefixUnversioned = "unversioned" + BuildIdSearchAttributeDelimiter = ":" + // UnversionedSearchAttribute is the sentinel value used to mark all unversioned workflows + UnversionedSearchAttribute = buildIdSearchAttributePrefixUnversioned +) // VersionedBuildIdSearchAttribute returns the search attribute value for an unversioned build id func VersionedBuildIdSearchAttribute(buildId string) string { @@ -53,3 +57,17 @@ func VersionStampToBuildIdSearchAttribute(stamp *commonpb.WorkerVersionStamp) st } return UnversionedBuildIdSearchAttribute(stamp.BuildId) } + +// FindBuildId returns the index of the first version set that contains buildId and the +// index of buildId within that set. Both results are -1 when buildId is not found, +// including when versionSets is empty or nil. +func FindBuildId(versionSets []*taskqueuepb.CompatibleVersionSet, buildId string) (setIndex, indexInSet int) { + for sidx, set := range versionSets { + for bidx, id := range set.GetBuildIds() { + if buildId == id { + return sidx, bidx + } + } + } + return -1, -1 +} diff --git a/schema/cassandra/temporal/schema.cql b/schema/cassandra/temporal/schema.cql index 9e808544044..2bdecdf824a 100644 --- a/schema/cassandra/temporal/schema.cql +++ 
b/schema/cassandra/temporal/schema.cql @@ -100,18 +100,17 @@ CREATE TABLE tasks ( -- OR -- Used as a mapping from build id to task queue CREATE TABLE task_queue_user_data ( - namespace_id uuid, - task_queue_name text, - build_id_if_row_is_an_index text, -- If this row is used as a mapping of build id to task queue, this will not be empty - data blob, -- temporal.server.api.persistence.v1.TaskQueueUserData - data_encoding text, -- Encoding type used for serialization, in practice this should always be proto3 - build_ids set, -- All active build ids in all version sets on this task queue (used in an index below) - version bigint, -- Version of this row, used for optimistic concurrency + namespace_id uuid, + task_queue_name text, + build_id text, -- If this row is used as a mapping of build id to task queue, this will not be empty + data blob, -- temporal.server.api.persistence.v1.TaskQueueUserData + data_encoding text, -- Encoding type used for serialization, in practice this should always be proto3 + version bigint, -- Version of this row, used for optimistic concurrency -- task_queue_name is not a part of the parititioning key to allow cheaply iterating all task queues in a single -- namespace. Access to this table should be infrequent enough that a single partition per namespace can be used. -- Note that this imposes a limit on total task queue user data within one namespace (see the relevant single -- partition Cassandra limits). 
- PRIMARY KEY ((namespace_id), build_id_if_row_is_an_index, task_queue_name) + PRIMARY KEY ((namespace_id), build_id, task_queue_name) ) WITH COMPACTION = { 'class': 'org.apache.cassandra.db.compaction.LeveledCompactionStrategy' }; diff --git a/schema/cassandra/temporal/versioned/v1.8/task_queue_user_data.cql b/schema/cassandra/temporal/versioned/v1.8/task_queue_user_data.cql index 4928178a8ff..f2ad8db0725 100644 --- a/schema/cassandra/temporal/versioned/v1.8/task_queue_user_data.cql +++ b/schema/cassandra/temporal/versioned/v1.8/task_queue_user_data.cql @@ -2,18 +2,17 @@ -- OR -- Used as a mapping from build id to task queue CREATE TABLE task_queue_user_data ( - namespace_id uuid, - task_queue_name text, - build_id_if_row_is_an_index text, -- If this row is used as a mapping of build id to task queue, this will not be empty - data blob, -- temporal.server.api.persistence.v1.TaskQueueUserData - data_encoding text, -- Encoding type used for serialization, in practice this should always be proto3 - build_ids set, -- All active build ids in all version sets on this task queue (used in an index below) - version bigint, -- Version of this row, used for optimistic concurrency + namespace_id uuid, + task_queue_name text, + build_id text, -- If this row is used as a mapping of build id to task queue, this will not be empty + data blob, -- temporal.server.api.persistence.v1.TaskQueueUserData + data_encoding text, -- Encoding type used for serialization, in practice this should always be proto3 + version bigint, -- Version of this row, used for optimistic concurrency -- task_queue_name is not a part of the parititioning key to allow cheaply iterating all task queues in a single -- namespace. Access to this table should be infrequent enough that a single partition per namespace can be used. -- Note that this imposes a limit on total task queue user data within one namespace (see the relevant single -- partition Cassandra limits). 
- PRIMARY KEY ((namespace_id), build_id_if_row_is_an_index, task_queue_name) + PRIMARY KEY ((namespace_id), build_id, task_queue_name) ) WITH COMPACTION = { 'class': 'org.apache.cassandra.db.compaction.LeveledCompactionStrategy' }; \ No newline at end of file diff --git a/service/frontend/task_reachability.go b/service/frontend/task_reachability.go index 059592e4e0e..d7c69cd4b77 100644 --- a/service/frontend/task_reachability.go +++ b/service/frontend/task_reachability.go @@ -30,61 +30,41 @@ import ( "strings" "sync" + "github.com/xwb1989/sqlparser" enumspb "go.temporal.io/api/enums/v1" "go.temporal.io/api/serviceerror" taskqueuepb "go.temporal.io/api/taskqueue/v1" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/server/api/matchingservice/v1" "go.temporal.io/server/common" + "go.temporal.io/server/common/future" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence/visibility/manager" + "go.temporal.io/server/common/searchattribute" "go.temporal.io/server/common/util" ) // Helper for deduping GetWorkerBuildIdCompatibility matching requests. 
type versionSetFetcher struct { - sync.Mutex - matchingClient matchingservice.MatchingServiceClient - notificationChs map[string]chan struct{} - responses map[string]versionSetFetcherResponse -} - -type versionSetFetcherResponse struct { - value *matchingservice.GetWorkerBuildIdCompatibilityResponse - err error + lock sync.Mutex + matchingClient matchingservice.MatchingServiceClient + futures map[string]future.Future[*matchingservice.GetWorkerBuildIdCompatibilityResponse] } func newVersionSetFetcher(matchingClient matchingservice.MatchingServiceClient) *versionSetFetcher { return &versionSetFetcher{ - matchingClient: matchingClient, - notificationChs: make(map[string]chan struct{}), - responses: make(map[string]versionSetFetcherResponse), + matchingClient: matchingClient, + futures: make(map[string]future.Future[*matchingservice.GetWorkerBuildIdCompatibilityResponse]), } } -func (f *versionSetFetcher) getStoredResponse(taskQueue string) (*matchingservice.GetWorkerBuildIdCompatibilityResponse, error) { - f.Lock() - defer f.Unlock() - if response, found := f.responses[taskQueue]; found { - return response.value, response.err - } - return nil, nil -} - -func (f *versionSetFetcher) fetchTaskQueueVersions(ctx context.Context, ns *namespace.Namespace, taskQueue string) (*matchingservice.GetWorkerBuildIdCompatibilityResponse, error) { - response, err := f.getStoredResponse(taskQueue) - if err != nil || response != nil { - return response, err - } - - f.Lock() - var waitCh chan struct{} - if ch, found := f.notificationChs[taskQueue]; found { - waitCh = ch - } else { - waitCh = make(chan struct{}) - f.notificationChs[taskQueue] = waitCh - +func (f *versionSetFetcher) getFuture(ctx context.Context, ns *namespace.Namespace, taskQueue string) future.Future[*matchingservice.GetWorkerBuildIdCompatibilityResponse] { + f.lock.Lock() + defer f.lock.Unlock() + _, found := f.futures[taskQueue] + if !found { + fut := 
future.NewFuture[*matchingservice.GetWorkerBuildIdCompatibilityResponse]() + f.futures[taskQueue] = fut go func() { value, err := f.matchingClient.GetWorkerBuildIdCompatibility(ctx, &matchingservice.GetWorkerBuildIdCompatibilityRequest{ NamespaceId: ns.ID().String(), @@ -93,16 +73,14 @@ func (f *versionSetFetcher) fetchTaskQueueVersions(ctx context.Context, ns *name TaskQueue: taskQueue, }, }) - f.Lock() - defer f.Unlock() - f.responses[taskQueue] = versionSetFetcherResponse{value: value, err: err} - close(waitCh) + fut.Set(value, err) }() } - f.Unlock() + return f.futures[taskQueue] +} - <-waitCh - return f.getStoredResponse(taskQueue) +func (f *versionSetFetcher) fetchTaskQueueVersions(ctx context.Context, ns *namespace.Namespace, taskQueue string) (*matchingservice.GetWorkerBuildIdCompatibilityResponse, error) { + return f.getFuture(ctx, ns, taskQueue).Get(ctx) } // Implementation of the GetWorkerTaskReachability API. Expects an already validated request. @@ -202,23 +180,10 @@ func (wh *WorkflowHandler) getTaskQueueReachability(ctx context.Context, request // Query for the unversioned worker isDefaultInQueue = len(request.versionSets) == 0 // Query workflows that have completed tasks marked with a sentinel "unversioned" search attribute. 
- buildIdsFilter = fmt.Sprintf("BuildIds = %q", common.UnversionedSearchAttribute) + buildIdsFilter = fmt.Sprintf(`%s = "%s"`, searchattribute.BuildIds, common.UnversionedSearchAttribute) } else { // Query for a versioned worker - setIdx := -1 - buildIdIdx := -1 - if len(request.versionSets) > 0 { - for sidx, set := range request.versionSets { - for bidx, id := range set.BuildIds { - if request.buildId == id { - setIdx = sidx - buildIdIdx = bidx - break - } - } - } - } - + setIdx, buildIdIdx := common.FindBuildId(request.versionSets, request.buildId) if setIdx == -1 { // build id not in set - unreachable return &taskQueueReachability, nil @@ -234,9 +199,9 @@ func (wh *WorkflowHandler) getTaskQueueReachability(ctx context.Context, request isDefaultInQueue = setIdx == len(request.versionSets)-1 escapedBuildIds := make([]string, len(set.GetBuildIds())) for i, buildId := range set.GetBuildIds() { - escapedBuildIds[i] = fmt.Sprintf("%q", common.VersionedBuildIdSearchAttribute(buildId)) + escapedBuildIds[i] = sqlparser.String(sqlparser.NewStrVal([]byte(common.VersionedBuildIdSearchAttribute(buildId)))) } - buildIdsFilter = fmt.Sprintf("BuildIds IN (%s)", strings.Join(escapedBuildIds, ",")) + buildIdsFilter = fmt.Sprintf("%s IN (%s)", searchattribute.BuildIds, strings.Join(escapedBuildIds, ",")) } if isDefaultInQueue { @@ -246,7 +211,7 @@ func (wh *WorkflowHandler) getTaskQueueReachability(ctx context.Context, request ) // Take into account started workflows that have not yet been processed by any worker. 
if request.reachabilityType != enumspb.TASK_REACHABILITY_CLOSED_WORKFLOWS { - buildIdsFilter = fmt.Sprintf("(BuildIds IS NULL OR %s)", buildIdsFilter) + buildIdsFilter = fmt.Sprintf("(%s IS NULL OR %s)", searchattribute.BuildIds, buildIdsFilter) } } @@ -268,9 +233,9 @@ func (wh *WorkflowHandler) queryVisibilityForExisitingWorkflowsReachability( statusFilter := "" switch reachabilityType { case enumspb.TASK_REACHABILITY_OPEN_WORKFLOWS: - statusFilter = " AND ExecutionStatus = \"Running\"" + statusFilter = fmt.Sprintf(` AND %s = "Running"`, searchattribute.ExecutionStatus) case enumspb.TASK_REACHABILITY_CLOSED_WORKFLOWS: - statusFilter = " AND ExecutionStatus != \"Running\"" + statusFilter = fmt.Sprintf(` AND %s != "Running"`, searchattribute.ExecutionStatus) case enumspb.TASK_REACHABILITY_UNSPECIFIED: reachabilityType = enumspb.TASK_REACHABILITY_EXISTING_WORKFLOWS statusFilter = "" @@ -285,7 +250,7 @@ func (wh *WorkflowHandler) queryVisibilityForExisitingWorkflowsReachability( req := manager.CountWorkflowExecutionsRequest{ NamespaceID: ns.ID(), Namespace: ns.Name(), - Query: fmt.Sprintf("TaskQueue = %q AND %s%s", taskQueue, buildIdsFilter, statusFilter), + Query: fmt.Sprintf("%s = %q AND %s%s", searchattribute.TaskQueue, taskQueue, buildIdsFilter, statusFilter), } // TODO(bergundy): is count more efficient than select with page size of 1? 
diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index 61129ab5995..44727a4b566 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -3653,7 +3653,7 @@ func (wh *WorkflowHandler) GetWorkerTaskReachability(ctx context.Context, reques } } if gotUnversionedRequest && len(request.GetTaskQueues()) == 0 { - return nil, serviceerror.NewInvalidArgument("Cannot get reachability of an unversioned worker without specifying at least one task queue") + return nil, serviceerror.NewInvalidArgument("Cannot get reachability of an unversioned worker without specifying at least one task queue (empty build id is interpreted as unversioned)") } for _, taskQueue := range request.GetTaskQueues() { taskQueue := &taskqueuepb.TaskQueue{Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL} diff --git a/service/matching/db.go b/service/matching/db.go index 0e9e073df2e..84d11e64aac 100644 --- a/service/matching/db.go +++ b/service/matching/db.go @@ -374,7 +374,7 @@ func (db *taskQueueDB) UpdateUserData(ctx context.Context, updateFn func(*persis // We iterate here but in practice there should only be a single build Id added when the limit is enforced. // We do not enforce the limit when applying replication events. 
for _, buildId := range added { - numTaskQueues, err := db.store.CountTaskQueuesByBuildId(ctx, &persistence.GetTaskQueuesByBuildIdRequest{ + numTaskQueues, err := db.store.CountTaskQueuesByBuildId(ctx, &persistence.CountTaskQueuesByBuildIdRequest{ NamespaceID: db.namespaceID.String(), BuildID: buildId, }) diff --git a/service/matching/matchingEngine_test.go b/service/matching/matchingEngine_test.go index 3cbb2e937d4..475db90c29d 100644 --- a/service/matching/matchingEngine_test.go +++ b/service/matching/matchingEngine_test.go @@ -2615,7 +2615,7 @@ func (*testTaskManager) GetTaskQueuesByBuildId(ctx context.Context, request *per } // CountTaskQueuesByBuildId implements persistence.TaskManager -func (*testTaskManager) CountTaskQueuesByBuildId(ctx context.Context, request *persistence.GetTaskQueuesByBuildIdRequest) (int, error) { +func (*testTaskManager) CountTaskQueuesByBuildId(ctx context.Context, request *persistence.CountTaskQueuesByBuildIdRequest) (int, error) { // This is only used to validate that the build id to task queue mapping is enforced (at the time of writing), report 0. 
return 0, nil } diff --git a/service/matching/taskQueueManager.go b/service/matching/taskQueueManager.go index 79c0ab0c03c..72f878a67c6 100644 --- a/service/matching/taskQueueManager.go +++ b/service/matching/taskQueueManager.go @@ -488,7 +488,6 @@ func (c *taskQueueManagerImpl) GetUserData(ctx context.Context) (*persistencespb return c.db.GetUserData(ctx) } -//nolint:revive // control coupling func (c *taskQueueManagerImpl) UpdateUserData(ctx context.Context, options UserDataUpdateOptions, updateFn UserDataUpdateFunc) error { newData, err := c.db.UpdateUserData(ctx, updateFn, options.TaskQueueLimitPerBuildId) if err != nil { diff --git a/tests/advanced_visibility_test.go b/tests/advanced_visibility_test.go index 9a6b191487a..0435c8b015e 100644 --- a/tests/advanced_visibility_test.go +++ b/tests/advanced_visibility_test.go @@ -2296,14 +2296,6 @@ func (s *advancedVisibilitySuite) TestWorkerTaskReachability_ByBuildId() { }, }) s.Require().NoError(err) - _, err = s.engine.UpdateWorkerBuildIdCompatibility(ctx, &workflowservice.UpdateWorkerBuildIdCompatibilityRequest{ - Namespace: s.namespace, - TaskQueue: tq2, - Operation: &workflowservice.UpdateWorkerBuildIdCompatibilityRequest_AddNewBuildIdInNewDefaultSet{ - AddNewBuildIdInNewDefaultSet: v0, - }, - }) - s.Require().NoError(err) _, err = s.engine.UpdateWorkerBuildIdCompatibility(ctx, &workflowservice.UpdateWorkerBuildIdCompatibilityRequest{ Namespace: s.namespace, TaskQueue: tq1, @@ -2315,6 +2307,14 @@ func (s *advancedVisibilitySuite) TestWorkerTaskReachability_ByBuildId() { }, }) s.Require().NoError(err) + _, err = s.engine.UpdateWorkerBuildIdCompatibility(ctx, &workflowservice.UpdateWorkerBuildIdCompatibilityRequest{ + Namespace: s.namespace, + TaskQueue: tq2, + Operation: &workflowservice.UpdateWorkerBuildIdCompatibilityRequest_AddNewBuildIdInNewDefaultSet{ + AddNewBuildIdInNewDefaultSet: v0, + }, + }) + s.Require().NoError(err) // Map v0 to a third queue to test limit enforcement _, err = 
s.engine.UpdateWorkerBuildIdCompatibility(ctx, &workflowservice.UpdateWorkerBuildIdCompatibilityRequest{