Skip to content

Commit

Permalink
universal-search: Query Group API and local mode (#690)
Browse files Browse the repository at this point in the history
* universal-search: Query Group API and local mode

* requires different logic for gRPC error assertion

* you come to me at runtime for a compile time issue

* suddenly throwing a different error

* test more group key types

* extend limit of prefetches during group_by

* rescoring is the issue

* one problem at a time please

* code review

* regen clients

* code review

* add lookup_from to query_points_groups

* test with_lookup

* test and fix gRPC

* drop dedicated conversion

* Update qdrant_client/qdrant_client.py

Co-authored-by: Luis Cossío <[email protected]>

* regen async

* Distribution-based score fusion in local mode (#703)

* pre-implement dbsf

* add dbsf congruence tests

* mypy lints

* add conversions

* tests: add test for dbsf conversion

---------

Co-authored-by: George Panchuk <[email protected]>

* Random sample in local mode (#705)

* pre-implement random sampling

* generate models

* add conversions and tests

* fix mypy lints

* tests: add test for sample random conversion

* use camelcase Sample.Random

* review fixes

* fix mypy

---------

Co-authored-by: George Panchuk <[email protected]>

* fix: add type ignore for mypy

* fix: fix type hints for 3.8

* fix: do not run mypy on async client generator in CI, simplify condition

* Grpc comparison in tests (#726)

* add parametrized fixture for using grpc too

* compare grpc and http without running each setup twice

* fix: fix exception types in invalid types test

* fix: remove random seed which led to an erroneous sequence

---------

Co-authored-by: Luis Cossío <[email protected]>

---------

Co-authored-by: Luis Cossío <[email protected]>
Co-authored-by: George Panchuk <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent 3c42c3a commit 44bbded
Show file tree
Hide file tree
Showing 24 changed files with 1,874 additions and 496 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/type-checkers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ jobs:
- name: mypy
run: |
if [[ ${{ matrix.python-version }} != "3.8" ]] || [[ ! -d "tools/async_client_generator" ]]; then
# async_qdrant_fastembed.py is autogenerated and erases type ignore statements from the source code
poetry run mypy . --exclude "async_qdrant_fastembed.py" --disallow-incomplete-defs --disallow-untyped-defs
fi
poetry run mypy . --exclude "tools/async_client_generator" --disallow-incomplete-defs --disallow-untyped-defs
- name: pyright
run: |
Expand Down
29 changes: 29 additions & 0 deletions qdrant_client/async_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,35 @@ async def query_points(
) -> types.QueryResponse:
raise NotImplementedError()

async def query_points_groups(
    self,
    collection_name: str,
    group_by: str,
    query: Union[
        types.PointId,
        List[float],
        List[List[float]],
        types.SparseVector,
        types.Query,
        types.NumpyArray,
        types.Document,
        None,
    ] = None,
    using: Optional[str] = None,
    prefetch: Union[types.Prefetch, List[types.Prefetch], None] = None,
    query_filter: Optional[types.Filter] = None,
    search_params: Optional[types.SearchParams] = None,
    limit: int = 10,
    group_size: int = 3,
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    score_threshold: Optional[float] = None,
    with_lookup: Optional[types.WithLookupInterface] = None,
    lookup_from: Optional[types.LookupLocation] = None,
    **kwargs: Any,
) -> types.GroupsResult:
    """Universal query endpoint that groups results by a payload field.

    Abstract interface method; concrete client implementations override it.
    Mirrors the signature of the public ``query_points_groups`` API so that
    remote and local backends stay interchangeable.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()

async def recommend_batch(
self, collection_name: str, requests: Sequence[types.RecommendRequest], **kwargs: Any
) -> List[List[types.ScoredPoint]]:
Expand Down
136 changes: 136 additions & 0 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,142 @@ async def query_points(
**kwargs,
)

async def query_points_groups(
    self,
    collection_name: str,
    group_by: str,
    query: Union[
        types.PointId,
        List[float],
        List[List[float]],
        types.SparseVector,
        types.Query,
        types.NumpyArray,
        types.Document,
        None,
    ] = None,
    using: Optional[str] = None,
    prefetch: Union[types.Prefetch, List[types.Prefetch], None] = None,
    query_filter: Optional[types.Filter] = None,
    search_params: Optional[types.SearchParams] = None,
    limit: int = 10,
    group_size: int = 3,
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    score_threshold: Optional[float] = None,
    with_lookup: Optional[types.WithLookupInterface] = None,
    lookup_from: Optional[types.LookupLocation] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.GroupsResult:
    """Universal endpoint to group on any available operation, such as search, recommendation, discovery, context search.

    Args:
        collection_name: Collection to search in
        query:
            Query for the chosen search type operation.
            - If `str` - use string as UUID of the existing point as a search query.
            - If `int` - use integer as ID of the existing point as a search query.
            - If `List[float]` - use as a dense vector for nearest search.
            - If `List[List[float]]` - use as a multi-vector for nearest search.
            - If `SparseVector` - use as a sparse vector for nearest search.
            - If `Query` - use as a query for specific search type.
            - If `NumpyArray` - use as a dense vector for nearest search.
            - If `Document` - infer vector from the document text and use
              it for nearest search (requires `fastembed` package installed).
            - If `None` - return first `limit` points from the collection.
        prefetch: prefetch queries to make a selection of the data to be used with the main query
        query_filter:
            - Exclude vectors which doesn't fit given conditions.
            - If `None` - search among all vectors
        search_params: Additional search params
        limit: How many results return
        group_size: How many results return for each group
        group_by: Name of the payload field to group by.
            Field must be of type "keyword" or "integer".
            Nested fields are specified using dot notation, e.g. "nested_field.subfield".
        with_payload:
            - Specify which stored payload should be attached to the result.
            - If `True` - attach all payload
            - If `False` - do not attach any payload
            - If List of string - include only specified fields
            - If `PayloadSelector` - use explicit rules
        with_vectors:
            - If `True` - Attach stored vector to the search result.
            - If `False` - Do not attach vector.
            - If List of string - include only specified fields
            - Default: `False`
        score_threshold:
            Define a minimal score threshold for the result.
            If defined, less similar results will not be returned.
            Score of the returned result might be higher or smaller than the threshold depending
            on the Distance function used.
            E.g. for cosine similarity only higher scores will be returned.
        using:
            Name of the vectors to use for query.
            If `None` - use default vectors or provided in named vector structures.
        with_lookup:
            Look for points in another collection using the group ids.
            If specified, each group will contain a record from the specified collection
            with the same id as the group id. In addition, the parameter allows to specify
            which parts of the record should be returned, like in `with_payload` and `with_vectors` parameters.
        lookup_from:
            Defines a location (collection and vector field name), used to lookup vectors being referenced in the query as IDs.
            If `None` - current collection will be used.
        consistency:
            Read consistency of the search. Defines how many replicas should be queried before returning the result. Values:
            - int - number of replicas to query, values should present in all queried replicas
            - 'majority' - query all replicas, but return values present in the majority of replicas
            - 'quorum' - query the majority of replicas, return values present in all of them
            - 'all' - query all replicas, and return values present in all replicas
        shard_key_selector:
            This parameter allows to specify which shards should be queried.
            If `None` - query all shards. Only works for collections with `custom` sharding method.
        timeout:
            Overrides global timeout for this search. Unit is seconds.

    Examples:

        `Search for closest points and group results`::

            qdrant.query_points_groups(
                collection_name="test_collection",
                query=[1.0, 0.1, 0.2, 0.7],
                group_by="color",
                group_size=3,
            )

    Returns:
        List of groups with not more than `group_size` hits in each group.
        Each group also contains an id of the group, which is the value of the payload field.
    """
    assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
    # May replace `query` (e.g. a Document) with inferred embeddings and
    # adjust `using`/`prefetch` accordingly before delegating.
    (using, query, prefetch) = self._resolve_query_to_embedding_embeddings_and_prefetch(
        query, prefetch, using, limit
    )
    return await self._client.query_points_groups(
        collection_name=collection_name,
        query=query,
        prefetch=prefetch,
        query_filter=query_filter,
        search_params=search_params,
        group_by=group_by,
        limit=limit,
        group_size=group_size,
        with_payload=with_payload,
        with_vectors=with_vectors,
        score_threshold=score_threshold,
        using=using,
        with_lookup=with_lookup,
        # Bug fix: `lookup_from` was accepted and documented but never
        # forwarded to the underlying client, silently dropping it.
        lookup_from=lookup_from,
        consistency=consistency,
        shard_key_selector=shard_key_selector,
        timeout=timeout,
        **kwargs,
    )

async def search_groups(
self,
collection_name: str,
Expand Down
132 changes: 131 additions & 1 deletion qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,136 @@ async def query_batch_points(
assert http_res is not None, "Query batch returned None"
return http_res

async def query_points_groups(
    self,
    collection_name: str,
    group_by: str,
    query: Union[
        types.PointId,
        List[float],
        List[List[float]],
        types.SparseVector,
        types.Query,
        types.NumpyArray,
        types.Document,
        None,
    ] = None,
    using: Optional[str] = None,
    prefetch: Union[types.Prefetch, List[types.Prefetch], None] = None,
    query_filter: Optional[types.Filter] = None,
    search_params: Optional[types.SearchParams] = None,
    limit: int = 10,
    group_size: int = 3,
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    score_threshold: Optional[float] = None,
    with_lookup: Optional[types.WithLookupInterface] = None,
    lookup_from: Optional[types.LookupLocation] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.GroupsResult:
    """Execute a grouped universal query against the remote server.

    Dispatches over gRPC or REST depending on the client preference.
    Arguments may be passed as either REST or gRPC models; each branch
    converts them to the representation it needs, so callers can mix both.

    Returns:
        types.GroupsResult: groups of scored points, keyed by the value of
        the `group_by` payload field (converted to REST models on the gRPC path).
    """
    if self._prefer_grpc:
        # Normalize REST-flavored models into their gRPC counterparts.
        if isinstance(query, get_args(models.Query)):
            query = RestToGrpc.convert_query(query)

        if isinstance(prefetch, models.Prefetch):
            prefetch = [RestToGrpc.convert_prefetch_query(prefetch)]

        if isinstance(prefetch, list):
            prefetch = [
                RestToGrpc.convert_prefetch_query(p) if isinstance(p, models.Prefetch) else p
                for p in prefetch
            ]

        if isinstance(query_filter, models.Filter):
            query_filter = RestToGrpc.convert_filter(model=query_filter)

        if isinstance(search_params, models.SearchParams):
            search_params = RestToGrpc.convert_search_params(search_params)

        if isinstance(with_payload, get_args_subscribed(models.WithPayloadInterface)):
            with_payload = RestToGrpc.convert_with_payload_interface(with_payload)

        if isinstance(with_vectors, get_args_subscribed(models.WithVector)):
            with_vectors = RestToGrpc.convert_with_vectors(with_vectors)

        if isinstance(with_lookup, models.WithLookup):
            with_lookup = RestToGrpc.convert_with_lookup(with_lookup)

        # A bare string is shorthand for "lookup in this collection".
        if isinstance(with_lookup, str):
            with_lookup = grpc.WithLookup(collection=with_lookup)

        if isinstance(lookup_from, models.LookupLocation):
            lookup_from = RestToGrpc.convert_lookup_location(lookup_from)

        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)

        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)

        result: grpc.QueryGroupsResponse = (
            await self.grpc_points.QueryGroups(
                grpc.QueryPointGroups(
                    collection_name=collection_name,
                    query=query,
                    prefetch=prefetch,
                    filter=query_filter,
                    limit=limit,
                    with_vectors=with_vectors,
                    with_payload=with_payload,
                    params=search_params,
                    score_threshold=score_threshold,
                    using=using,
                    group_by=group_by,
                    group_size=group_size,
                    with_lookup=with_lookup,
                    lookup_from=lookup_from,
                    timeout=timeout,
                    shard_key_selector=shard_key_selector,
                    read_consistency=consistency,
                ),
                # Bug fix: the condition was inverted (`timeout is None`),
                # which discarded a caller-supplied timeout and sent None as
                # the channel timeout when no override was given. Use the
                # per-call timeout when provided, else the client default.
                timeout=timeout if timeout is not None else self._timeout,
            )
        ).result

        return GrpcToRest.convert_groups_result(result)
    else:
        # Normalize gRPC-flavored messages into REST models.
        if isinstance(query, grpc.Query):
            query = GrpcToRest.convert_query(query)

        if isinstance(prefetch, grpc.PrefetchQuery):
            prefetch = GrpcToRest.convert_prefetch_query(prefetch)

        if isinstance(prefetch, list):
            prefetch = [
                GrpcToRest.convert_prefetch_query(p)
                if isinstance(p, grpc.PrefetchQuery)
                else p
                for p in prefetch
            ]

        if isinstance(query_filter, grpc.Filter):
            query_filter = GrpcToRest.convert_filter(model=query_filter)

        if isinstance(search_params, grpc.SearchParams):
            search_params = GrpcToRest.convert_search_params(search_params)

        if isinstance(with_payload, grpc.WithPayloadSelector):
            with_payload = GrpcToRest.convert_with_payload_selector(with_payload)

        if isinstance(with_lookup, grpc.WithLookup):
            with_lookup = GrpcToRest.convert_with_lookup(with_lookup)

        if isinstance(lookup_from, grpc.LookupLocation):
            lookup_from = GrpcToRest.convert_lookup_location(lookup_from)

        query_request = models.QueryGroupsRequest(
            shard_key=shard_key_selector,
            prefetch=prefetch,
            query=query,
            using=using,
            filter=query_filter,
            params=search_params,
            score_threshold=score_threshold,
            limit=limit,
            group_by=group_by,
            group_size=group_size,
            with_vector=with_vectors,
            with_payload=with_payload,
            with_lookup=with_lookup,
            lookup_from=lookup_from,
        )

        query_result = await self.http.points_api.query_points_groups(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            query_groups_request=query_request,
        )

        assert query_result is not None, "Query points groups API returned None"
        return query_result.result

async def search_groups(
self,
collection_name: str,
Expand Down Expand Up @@ -612,7 +742,7 @@ async def search_groups(
if isinstance(with_lookup, models.WithLookup):
with_lookup = RestToGrpc.convert_with_lookup(with_lookup)
if isinstance(with_lookup, str):
with_lookup = grpc.WithLookup(lookup=with_lookup)
with_lookup = grpc.WithLookup(collection=with_lookup)
if isinstance(query_vector, types.NamedVector):
vector = query_vector.vector
vector_name = query_vector.name
Expand Down
29 changes: 29 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,35 @@ def query_points(
) -> types.QueryResponse:
raise NotImplementedError()

def query_points_groups(
    self,
    collection_name: str,
    group_by: str,
    query: Union[
        types.PointId,
        List[float],
        List[List[float]],
        types.SparseVector,
        types.Query,
        types.NumpyArray,
        types.Document,
        None,
    ] = None,
    using: Optional[str] = None,
    prefetch: Union[types.Prefetch, List[types.Prefetch], None] = None,
    query_filter: Optional[types.Filter] = None,
    search_params: Optional[types.SearchParams] = None,
    limit: int = 10,
    group_size: int = 3,
    with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
    with_vectors: Union[bool, Sequence[str]] = False,
    score_threshold: Optional[float] = None,
    with_lookup: Optional[types.WithLookupInterface] = None,
    lookup_from: Optional[types.LookupLocation] = None,
    **kwargs: Any,
) -> types.GroupsResult:
    """Universal query endpoint that groups results by a payload field.

    Abstract interface method; concrete client implementations override it.
    Signature mirrors the async variant in `async_client_base.py` so both
    client families expose the same contract.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()

def recommend_batch(
self,
collection_name: str,
Expand Down
Loading

0 comments on commit 44bbded

Please sign in to comment.