Skip to content

Commit

Permalink
Feature: Distance Matrix API (#769)
Browse files Browse the repository at this point in the history
* add remote impls

* add client impl

* regen async client

* start local

* regen async

* local mode

* start congruence

* Fix local mode

* regen async of course

* test filtering

* fix min samples count

* simplify comparison loop

* simplify samples loop

* add rest/gRPC conversion tests

* fix conversions + tests
  • Loading branch information
agourlay authored Oct 4, 2024
1 parent 9f736ac commit ad89cb6
Show file tree
Hide file tree
Showing 14 changed files with 826 additions and 0 deletions.
22 changes: 22 additions & 0 deletions qdrant_client/async_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,28 @@ async def search_groups(
) -> types.GroupsResult:
raise NotImplementedError()

async def search_distance_matrix_offsets(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    **kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
    """Interface stub for the offset-based distance matrix search.

    This base version performs no work: it unconditionally raises
    ``NotImplementedError``.

    Args:
        collection_name: Name of the collection.
        query_filter: Optional filter.
        limit: Defaults to 3.
        sample: Defaults to 10.
        using: Optional vector name.
        **kwargs: Additional keyword arguments; unused here.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

async def search_distance_matrix_pairs(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    **kwargs: Any,
) -> types.SearchMatrixPairsResponse:
    """Interface stub for the pair-based distance matrix search.

    This base version performs no work: it unconditionally raises
    ``NotImplementedError``.

    Args:
        collection_name: Name of the collection.
        query_filter: Optional filter.
        limit: Defaults to 3.
        sample: Defaults to 10.
        using: Optional vector name.
        **kwargs: Additional keyword arguments; unused here.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

async def query_batch_points(
self, collection_name: str, requests: Sequence[types.QueryRequest], **kwargs: Any
) -> List[types.QueryResponse]:
Expand Down
92 changes: 92 additions & 0 deletions qdrant_client/async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,98 @@ async def recommend(
**kwargs,
)

async def search_distance_matrix_pairs(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    consistency: Optional[types.ReadConsistency] = None,
    timeout: Optional[int] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.SearchMatrixPairsResponse:
    """Compute a distance matrix over sampled points, returned as pairs.

    The call is forwarded unchanged to ``self._client``.

    Args:
        collection_name: Name of the collection.
        query_filter: Filter to apply.
        limit: Number of neighbours to find for each sample.
        sample: Number of points to select and search within.
        using: Name of the vectors to use for search.
            If `None` - the default vectors are used.
        consistency:
            Read consistency of the search, i.e. how many replicas are
            queried before a result is returned. Values:
            - int - number of replicas to query; values must be present in all queried replicas
            - 'majority' - query all replicas, return values present in the majority of them
            - 'quorum' - query the majority of replicas, return values present in all of them
            - 'all' - query all replicas, return values present in all of them
        timeout: Overrides the global timeout for this search. Unit is seconds.
        shard_key_selector:
            Selects which shards should be queried.
            If `None` - all shards are queried. Only works for collections
            with the `custom` sharding method.
        **kwargs: No extra keyword arguments are accepted; any present
            trips the assertion below.

    Returns:
        The distance matrix in the pair-based encoding.
    """
    # Reject unknown keyword arguments up front.
    assert not kwargs, f"Unknown arguments: {list(kwargs.keys())}"

    pairs_response = await self._client.search_distance_matrix_pairs(
        collection_name=collection_name,
        query_filter=query_filter,
        limit=limit,
        sample=sample,
        using=using,
        consistency=consistency,
        timeout=timeout,
        shard_key_selector=shard_key_selector,
        **kwargs,
    )
    return pairs_response

async def search_distance_matrix_offsets(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    consistency: Optional[types.ReadConsistency] = None,
    timeout: Optional[int] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    **kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
    """Compute a distance matrix over sampled points, returned as offsets.

    The call is forwarded unchanged to ``self._client``.

    Args:
        collection_name: Name of the collection.
        query_filter: Filter to apply.
        limit: Number of neighbours to find for each sample.
        sample: Number of points to select and search within.
        using: Name of the vectors to use for search.
            If `None` - the default vectors are used.
        consistency:
            Read consistency of the search, i.e. how many replicas are
            queried before a result is returned. Values:
            - int - number of replicas to query; values must be present in all queried replicas
            - 'majority' - query all replicas, return values present in the majority of them
            - 'quorum' - query the majority of replicas, return values present in all of them
            - 'all' - query all replicas, return values present in all of them
        timeout: Overrides the global timeout for this search. Unit is seconds.
        shard_key_selector:
            Selects which shards should be queried.
            If `None` - all shards are queried. Only works for collections
            with the `custom` sharding method.
        **kwargs: No extra keyword arguments are accepted; any present
            trips the assertion below.

    Returns:
        The distance matrix in the offset-based encoding.
    """
    # Reject unknown keyword arguments up front.
    assert not kwargs, f"Unknown arguments: {list(kwargs.keys())}"

    offsets_response = await self._client.search_distance_matrix_offsets(
        collection_name=collection_name,
        query_filter=query_filter,
        limit=limit,
        sample=sample,
        using=using,
        consistency=consistency,
        timeout=timeout,
        shard_key_selector=shard_key_selector,
        **kwargs,
    )
    return offsets_response

async def recommend_groups(
self,
collection_name: str,
Expand Down
104 changes: 104 additions & 0 deletions qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,110 @@ async def search_groups(
)
).result

async def search_distance_matrix_pairs(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.SearchMatrixPairsResponse:
    """Request a pair-encoded distance matrix over gRPC or REST.

    When ``self._prefer_grpc`` is set, REST-typed arguments are converted to
    their gRPC form and ``SearchMatrixPairs`` is called; the gRPC result is
    converted back to the REST response model. Otherwise the OpenAPI client
    is used, and a gRPC filter is converted to its REST form first.

    Args:
        collection_name: Name of the collection.
        query_filter: Optional filter, in either REST or gRPC form.
        limit: Forwarded as the request's ``limit``.
        sample: Forwarded as the request's ``sample``.
        using: Forwarded as the request's ``using``.
        consistency: Optional read consistency.
        shard_key_selector: Optional shard key selector.
        timeout: Optional timeout. On the gRPC path it is sent in the request
            and also used as the call timeout, falling back to
            ``self._timeout`` when `None`.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The pair-based distance matrix response.

    Raises:
        AssertionError: If the REST call yields a `None` result.
    """
    if self._prefer_grpc:
        if isinstance(query_filter, models.Filter):
            query_filter = RestToGrpc.convert_filter(model=query_filter)
        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)
        response = await self.grpc_points.SearchMatrixPairs(
            grpc.SearchMatrixPoints(
                collection_name=collection_name,
                filter=query_filter,
                sample=sample,
                limit=limit,
                using=using,
                timeout=timeout,
                read_consistency=consistency,
                shard_key_selector=shard_key_selector,
            ),
            timeout=timeout if timeout is not None else self._timeout,
        )
        return GrpcToRest.convert_search_matrix_pairs(response.result)

    # REST path: the converted filter must replace ``query_filter`` itself,
    # otherwise the gRPC object would be sent in the REST request unchanged.
    if isinstance(query_filter, grpc.Filter):
        query_filter = GrpcToRest.convert_filter(model=query_filter)
    search_matrix_result = (
        await self.openapi_client.points_api.search_points_matrix_pairs(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            search_matrix_request=models.SearchMatrixRequest(
                shard_key=shard_key_selector,
                limit=limit,
                sample=sample,
                using=using,
                filter=query_filter,
            ),
        )
    ).result
    assert search_matrix_result is not None, "Search matrix pairs returned None result"
    return search_matrix_result

async def search_distance_matrix_offsets(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    consistency: Optional[types.ReadConsistency] = None,
    shard_key_selector: Optional[types.ShardKeySelector] = None,
    timeout: Optional[int] = None,
    **kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
    """Request an offset-encoded distance matrix over gRPC or REST.

    When ``self._prefer_grpc`` is set, REST-typed arguments are converted to
    their gRPC form and ``SearchMatrixOffsets`` is called; the gRPC result is
    converted back to the REST response model. Otherwise the OpenAPI client
    is used, and a gRPC filter is converted to its REST form first.

    Args:
        collection_name: Name of the collection.
        query_filter: Optional filter, in either REST or gRPC form.
        limit: Forwarded as the request's ``limit``.
        sample: Forwarded as the request's ``sample``.
        using: Forwarded as the request's ``using``.
        consistency: Optional read consistency.
        shard_key_selector: Optional shard key selector.
        timeout: Optional timeout. On the gRPC path it is sent in the request
            and also used as the call timeout, falling back to
            ``self._timeout`` when `None`.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The offset-based distance matrix response.

    Raises:
        AssertionError: If the REST call yields a `None` result.
    """
    if self._prefer_grpc:
        if isinstance(query_filter, models.Filter):
            query_filter = RestToGrpc.convert_filter(model=query_filter)
        if isinstance(shard_key_selector, get_args_subscribed(models.ShardKeySelector)):
            shard_key_selector = RestToGrpc.convert_shard_key_selector(shard_key_selector)
        if isinstance(consistency, get_args_subscribed(models.ReadConsistency)):
            consistency = RestToGrpc.convert_read_consistency(consistency)
        response = await self.grpc_points.SearchMatrixOffsets(
            grpc.SearchMatrixPoints(
                collection_name=collection_name,
                filter=query_filter,
                sample=sample,
                limit=limit,
                using=using,
                timeout=timeout,
                read_consistency=consistency,
                shard_key_selector=shard_key_selector,
            ),
            timeout=timeout if timeout is not None else self._timeout,
        )
        return GrpcToRest.convert_search_matrix_offsets(response.result)

    # REST path: the converted filter must replace ``query_filter`` itself,
    # otherwise the gRPC object would be sent in the REST request unchanged.
    if isinstance(query_filter, grpc.Filter):
        query_filter = GrpcToRest.convert_filter(model=query_filter)
    search_matrix_result = (
        await self.openapi_client.points_api.search_points_matrix_offsets(
            collection_name=collection_name,
            consistency=consistency,
            timeout=timeout,
            search_matrix_request=models.SearchMatrixRequest(
                shard_key=shard_key_selector,
                limit=limit,
                sample=sample,
                using=using,
                filter=query_filter,
            ),
        )
    ).result
    assert search_matrix_result is not None, "Search matrix offsets returned None result"
    return search_matrix_result

async def recommend_batch(
self,
collection_name: str,
Expand Down
22 changes: 22 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ def search_groups(
) -> types.GroupsResult:
raise NotImplementedError()

def search_distance_matrix_offsets(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    **kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
    """Interface stub for the offset-based distance matrix search.

    This base version performs no work: it unconditionally raises
    ``NotImplementedError``.

    Args:
        collection_name: Name of the collection.
        query_filter: Optional filter.
        limit: Defaults to 3.
        sample: Defaults to 10.
        using: Optional vector name.
        **kwargs: Additional keyword arguments; unused here.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

def search_distance_matrix_pairs(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    **kwargs: Any,
) -> types.SearchMatrixPairsResponse:
    """Interface stub for the pair-based distance matrix search.

    This base version performs no work: it unconditionally raises
    ``NotImplementedError``.

    Args:
        collection_name: Name of the collection.
        query_filter: Optional filter.
        limit: Defaults to 3.
        sample: Defaults to 10.
        using: Optional vector name.
        **kwargs: Additional keyword arguments; unused here.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()

def query_batch_points(
self,
collection_name: str,
Expand Down
4 changes: 4 additions & 0 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def get_args_subscribed(tp): # type: ignore

# Facet types are aliased directly to their REST models.
FacetValue: TypeAlias = rest.FacetValue
FacetResponse: TypeAlias = rest.FacetResponse
# A distance-matrix request may be either the REST model or the gRPC message.
SearchMatrixRequest = Union[rest.SearchMatrixRequest, grpc.SearchMatrixPoints]
# Distance-matrix responses (and the pair element) are aliased to the REST models.
SearchMatrixOffsetsResponse: TypeAlias = rest.SearchMatrixOffsetsResponse
SearchMatrixPairsResponse: TypeAlias = rest.SearchMatrixPairsResponse
SearchMatrixPair: TypeAlias = rest.SearchMatrixPair

VersionInfo: TypeAlias = rest.VersionInfo

Expand Down
46 changes: 46 additions & 0 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,29 @@ def convert_health_check_reply(cls, model: grpc.HealthCheckReply) -> rest.Versio
commit=model.commit if model.HasField("commit") else None,
)

@classmethod
def convert_search_matrix_pair(cls, model: grpc.SearchMatrixPair) -> rest.SearchMatrixPair:
    """Translate one gRPC matrix pair into the REST pair model.

    Both point ids go through ``convert_point_id``; the score is copied as is.
    """
    first_id = cls.convert_point_id(model.a)
    second_id = cls.convert_point_id(model.b)
    return rest.SearchMatrixPair(a=first_id, b=second_id, score=model.score)

@classmethod
def convert_search_matrix_pairs(cls, model: grpc.SearchMatrixPairs) -> rest.SearchMatrixPairsResponse:
    """Translate a gRPC pairs message into the REST pairs response.

    Each entry of ``model.pairs`` is converted with
    ``convert_search_matrix_pair``, preserving order.
    """
    rest_pairs = []
    for grpc_pair in model.pairs:
        rest_pairs.append(cls.convert_search_matrix_pair(grpc_pair))
    return rest.SearchMatrixPairsResponse(pairs=rest_pairs)

@classmethod
def convert_search_matrix_offsets(cls, model: grpc.SearchMatrixOffsets) -> rest.SearchMatrixOffsetsResponse:
    """Translate a gRPC offsets message into the REST offsets response.

    The row offsets, column offsets and scores are copied into plain lists;
    every id is converted with ``convert_point_id``.
    """
    row_offsets = list(model.offsets_row)
    col_offsets = list(model.offsets_col)
    score_values = list(model.scores)
    point_ids = [cls.convert_point_id(grpc_id) for grpc_id in model.ids]
    return rest.SearchMatrixOffsetsResponse(
        offsets_row=row_offsets,
        offsets_col=col_offsets,
        scores=score_values,
        ids=point_ids,
    )


# ----------------------------------------
#
Expand Down Expand Up @@ -3581,3 +3604,26 @@ def convert_health_check_reply(cls, model: rest.VersionInfo) -> grpc.HealthCheck
version=model.version,
commit=model.commit,
)

@classmethod
def convert_search_matrix_pair(cls, model: rest.SearchMatrixPair) -> grpc.SearchMatrixPair:
    """Translate one REST matrix pair into the gRPC pair message.

    Both point ids go through ``convert_extended_point_id``; the score is
    copied as is.
    """
    first_id = cls.convert_extended_point_id(model.a)
    second_id = cls.convert_extended_point_id(model.b)
    return grpc.SearchMatrixPair(a=first_id, b=second_id, score=model.score)

@classmethod
def convert_search_matrix_pairs(cls, model: rest.SearchMatrixPairsResponse) -> grpc.SearchMatrixPairs:
    """Translate a REST pairs response into the gRPC pairs message.

    Each entry of ``model.pairs`` is converted with
    ``convert_search_matrix_pair``, preserving order.
    """
    grpc_pairs = []
    for rest_pair in model.pairs:
        grpc_pairs.append(cls.convert_search_matrix_pair(rest_pair))
    return grpc.SearchMatrixPairs(pairs=grpc_pairs)

@classmethod
def convert_search_matrix_offsets(cls, model: rest.SearchMatrixOffsetsResponse) -> grpc.SearchMatrixOffsets:
    """Translate a REST offsets response into the gRPC offsets message.

    The row offsets, column offsets and scores are copied into plain lists;
    every id is converted with ``convert_extended_point_id``.
    """
    row_offsets = list(model.offsets_row)
    col_offsets = list(model.offsets_col)
    score_values = list(model.scores)
    point_ids = [cls.convert_extended_point_id(rest_id) for rest_id in model.ids]
    return grpc.SearchMatrixOffsets(
        offsets_row=row_offsets,
        offsets_col=col_offsets,
        scores=score_values,
        ids=point_ids,
    )
28 changes: 28 additions & 0 deletions qdrant_client/local/async_qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,34 @@ async def search(
score_threshold=score_threshold,
)

async def search_distance_matrix_offsets(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    **kwargs: Any,
) -> types.SearchMatrixOffsetsResponse:
    """Compute the offset-based distance matrix on a local collection.

    Looks the collection up via ``self._get_collection`` and delegates to its
    ``search_distance_matrix_offsets`` method. ``**kwargs`` is accepted but
    not forwarded.
    """
    local_collection = self._get_collection(collection_name)
    offsets_response = local_collection.search_distance_matrix_offsets(
        query_filter=query_filter,
        limit=limit,
        sample=sample,
        using=using,
    )
    return offsets_response

async def search_distance_matrix_pairs(
    self,
    collection_name: str,
    query_filter: Optional[types.Filter] = None,
    limit: int = 3,
    sample: int = 10,
    using: Optional[str] = None,
    **kwargs: Any,
) -> types.SearchMatrixPairsResponse:
    """Compute the pair-based distance matrix on a local collection.

    Looks the collection up via ``self._get_collection`` and delegates to its
    ``search_distance_matrix_pairs`` method. ``**kwargs`` is accepted but
    not forwarded.
    """
    local_collection = self._get_collection(collection_name)
    pairs_response = local_collection.search_distance_matrix_pairs(
        query_filter=query_filter,
        limit=limit,
        sample=sample,
        using=using,
    )
    return pairs_response

async def search_groups(
self,
collection_name: str,
Expand Down
Loading

0 comments on commit ad89cb6

Please sign in to comment.