Skip to content

Commit

Permalink
add tests for vector operations (#275)
Browse files Browse the repository at this point in the history
* add tests for vector operations

* rename vectors->points in vector update API

* fix tests
  • Loading branch information
generall authored Sep 4, 2023
1 parent 77837b0 commit 578db81
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 24 deletions.
2 changes: 1 addition & 1 deletion qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def upsert(
def update_vectors(
self,
collection_name: str,
vectors: Sequence[types.PointVectors],
points: Sequence[types.PointVectors],
**kwargs: Any,
) -> types.UpdateResult:
raise NotImplementedError()
Expand Down
21 changes: 20 additions & 1 deletion qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,29 @@
from qdrant_client import grpc as grpc
from qdrant_client.http import models as rest

typing_remap = {
rest.StrictStr: str,
rest.StrictInt: int,
rest.StrictFloat: float,
rest.StrictBool: bool,
}


def remap_type(tp: type) -> type:
"""Remap type to a type that can be used in type annotations
Pydantic uses custom types for strict types, so we need to remap them to standard types
so that they can be used in type annotations and isinstance checks
"""
return typing_remap.get(tp, tp)


def get_args_subscribed(tp: type) -> Tuple:
"""Get type arguments with all substitutions performed. Supports subscripted generics having __origin__"""
return tuple(arg if not hasattr(arg, "__origin__") else arg.__origin__ for arg in get_args(tp))
return tuple(
remap_type(arg if not hasattr(arg, "__origin__") else arg.__origin__)
for arg in get_args(tp)
)


Filter = Union[rest.Filter, grpc.Filter]
Expand Down
8 changes: 4 additions & 4 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,11 @@ def _update_named_vectors(self, idx: int, vectors: Dict[str, List[float]]) -> No
for vector_name, vector in vectors.items():
self.vectors[vector_name][idx] = np.array(vector)

def update_vectors(self, vectors: Sequence[types.PointVectors]) -> None:
for vector in vectors:
point_id = vector.id
def update_vectors(self, points: Sequence[types.PointVectors]) -> None:
for point in points:
point_id = point.id
idx = self.ids[point_id]
vector_struct = vector.vector
vector_struct = point.vector
if isinstance(vector_struct, list):
fixed_vectors = {DEFAULT_VECTOR_NAME: vector_struct}
else:
Expand Down
4 changes: 2 additions & 2 deletions qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,11 @@ def upsert(
def update_vectors(
self,
collection_name: str,
vectors: Sequence[types.PointVectors],
points: Sequence[types.PointVectors],
**kwargs: Any,
) -> types.UpdateResult:
collection = self._get_collection(collection_name)
collection.update_vectors(vectors)
collection.update_vectors(points)
return self._default_update_result()

def delete_vectors(
Expand Down
8 changes: 4 additions & 4 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from qdrant_client.conversions import common_types as types
from qdrant_client.http import ApiClient, SyncApis
from qdrant_client.local.qdrant_local import QdrantLocal
from qdrant_client.qdrant_fastembed import QdrantFastembedMixin
from qdrant_client.migrate import migrate
from qdrant_client.qdrant_fastembed import QdrantFastembedMixin
from qdrant_client.qdrant_remote import QdrantRemote


Expand Down Expand Up @@ -779,7 +779,7 @@ def upsert(
def update_vectors(
self,
collection_name: str,
vectors: Sequence[types.PointVectors],
points: Sequence[types.PointVectors],
wait: bool = True,
ordering: Optional[types.WriteOrdering] = None,
**kwargs: Any,
Expand All @@ -788,7 +788,7 @@ def update_vectors(
Args:
collection_name: Name of the collection to update vectors in
vectors: List of (id, vector) pairs to update. Vector might be a list of numbers or a dict of named vectors.
points: List of (id, vector) pairs to update. Vector might be a list of numbers or a dict of named vectors.
Example
- `PointVectors(id=1, vector=[1, 2, 3])`
- `PointVectors(id=2, vector={'vector_1': [1, 2, 3], 'vector_2': [4, 5, 6]})`
Expand All @@ -807,7 +807,7 @@ def update_vectors(

return self._client.update_vectors(
collection_name=collection_name,
vectors=vectors,
points=points,
wait=wait,
ordering=ordering,
)
Expand Down
10 changes: 5 additions & 5 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,13 +1060,13 @@ def upsert(
def update_vectors(
self,
collection_name: str,
vectors: Sequence[types.PointVectors],
points: Sequence[types.PointVectors],
wait: bool = True,
ordering: Optional[types.WriteOrdering] = None,
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
vectors = [RestToGrpc.convert_point_vectors(vector) for vector in vectors]
points = [RestToGrpc.convert_point_vectors(point) for point in points]

if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
Expand All @@ -1075,7 +1075,7 @@ def update_vectors(
grpc.UpdatePointVectors(
collection_name=collection_name,
wait=wait,
vectors=vectors,
points=points,
ordering=ordering,
)
).result
Expand All @@ -1085,7 +1085,7 @@ def update_vectors(
return self.openapi_client.points_api.update_vectors(
collection_name=collection_name,
wait=wait,
update_vectors=models.UpdateVectors(points=vectors),
update_vectors=models.UpdateVectors(points=points),
ordering=ordering,
).result

Expand All @@ -1111,7 +1111,7 @@ def delete_vectors(
vectors=grpc.VectorsSelector(
names=vectors,
),
points=points,
points_selector=points,
ordering=ordering,
)
).result
Expand Down
6 changes: 5 additions & 1 deletion tests/congruence_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def compare_collections(
collection_1 = client_1.get_collection(collection_name)
collection_2 = client_2.get_collection(collection_name)

assert all(getattr(collection_1, attr) == getattr(collection_2, attr) for attr in attrs)
for attr in attrs:
assert getattr(collection_1, attr) == getattr(collection_2, attr), (
f"client_1.{attr} = {getattr(collection_1, attr)}, "
f"client_2.{attr} = {getattr(collection_2, attr)}"
)

# num_vectors * 2 to be sure that we have no excess points uploaded
compare_client_results(
Expand Down
2 changes: 1 addition & 1 deletion tests/congruence_tests/test_delete_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_delete_points(local_client, remote_client):
local_client.delete(COLLECTION_NAME, found_ids)
remote_client.delete(COLLECTION_NAME, found_ids)

compare_collections(local_client, remote_client, 100)
compare_collections(local_client, remote_client, 100, attrs=("points_count",))

compare_client_results(
local_client,
Expand Down
4 changes: 2 additions & 2 deletions tests/congruence_tests/test_optional_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def test_simple_opt_vectors_search():

local_client.update_vectors(
collection_name=COLLECTION_NAME,
vectors=update_vectors,
points=update_vectors,
)

remote_client.update_vectors(
collection_name=COLLECTION_NAME,
vectors=update_vectors,
points=update_vectors,
)

compare_client_results(
Expand Down
68 changes: 65 additions & 3 deletions tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from qdrant_client import QdrantClient
from qdrant_client._pydantic_compat import to_dict
from qdrant_client.conversions.common_types import Record
from qdrant_client.conversions.common_types import PointVectors, Record
from qdrant_client.conversions.conversion import grpc_to_payload, json_to_value
from qdrant_client.local.qdrant_local import QdrantLocal
from qdrant_client.models import (
Expand Down Expand Up @@ -727,7 +727,10 @@ def test_qdrant_client_integration_update_collection(prefer_grpc):
if version is not None and version >= "v1.4.0":
assert collection_info.config.params.vectors["text"].hnsw_config.m == 32
assert collection_info.config.params.vectors["text"].hnsw_config.ef_construct == 123
assert collection_info.config.params.vectors["text"].quantization_config.product.compression == CompressionRatio.X32
assert (
collection_info.config.params.vectors["text"].quantization_config.product.compression
== CompressionRatio.X32
)
assert collection_info.config.params.vectors["text"].quantization_config.product.always_ram
assert collection_info.config.params.vectors["text"].on_disk
assert collection_info.config.hnsw_config.ef_construct == 123
Expand Down Expand Up @@ -837,6 +840,65 @@ def test_quantization_config(prefer_grpc):
)


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_vector_update(prefer_grpc):
client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT)

client.recreate_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(size=DIM, distance=Distance.DOT),
timeout=TIMEOUT,
)

uuid1 = str(uuid.uuid4())
uuid2 = str(uuid.uuid4())
uuid3 = str(uuid.uuid4())
uuid4 = str(uuid.uuid4())

client.upsert(
collection_name=COLLECTION_NAME,
points=[
PointStruct(id=uuid1, payload={"a": 1}, vector=np.random.rand(DIM).tolist()),
PointStruct(id=uuid2, payload={"a": 2}, vector=np.random.rand(DIM).tolist()),
PointStruct(id=uuid3, payload={"b": 1}, vector=np.random.rand(DIM).tolist()),
PointStruct(id=uuid4, payload={"b": 2}, vector=np.random.rand(DIM).tolist()),
],
wait=True,
)

client.update_vectors(
collection_name=COLLECTION_NAME,
points=[
PointVectors(
id=uuid2,
vector=[1.0] * DIM,
)
],
)

result = client.retrieve(
collection_name=COLLECTION_NAME,
ids=[uuid2],
with_vectors=True,
)[0]

assert result.vector == [1] * DIM

client.delete_vectors(
collection_name=COLLECTION_NAME,
vectors=[""],
points=Filter(must=[FieldCondition(key="b", range=Range(gte=1))]),
)

result = client.retrieve(
collection_name=COLLECTION_NAME,
ids=[uuid4],
with_vectors=True,
)[0]

assert result.vector == {}


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_conditional_payload_update(prefer_grpc):
client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT)
Expand Down Expand Up @@ -1157,4 +1219,4 @@ def test_client_close():
test_points_crud()
test_has_id_condition()
test_insert_float()
test_legacy_imports()
test_legacy_imports()

0 comments on commit 578db81

Please sign in to comment.