diff --git a/qdrant_client/client_base.py b/qdrant_client/client_base.py index 5d98fe52..5e9439c1 100644 --- a/qdrant_client/client_base.py +++ b/qdrant_client/client_base.py @@ -133,7 +133,7 @@ def upsert( def update_vectors( self, collection_name: str, - vectors: Sequence[types.PointVectors], + points: Sequence[types.PointVectors], **kwargs: Any, ) -> types.UpdateResult: raise NotImplementedError() diff --git a/qdrant_client/conversions/common_types.py b/qdrant_client/conversions/common_types.py index 73099263..6cf041ef 100644 --- a/qdrant_client/conversions/common_types.py +++ b/qdrant_client/conversions/common_types.py @@ -13,10 +13,29 @@ from qdrant_client import grpc as grpc from qdrant_client.http import models as rest +typing_remap = { + rest.StrictStr: str, + rest.StrictInt: int, + rest.StrictFloat: float, + rest.StrictBool: bool, +} + + +def remap_type(tp: type) -> type: + """Remap type to a type that can be used in type annotations + + Pydantic uses custom types for strict types, so we need to remap them to standard types + so that they can be used in type annotations and isinstance checks + """ + return typing_remap.get(tp, tp) + def get_args_subscribed(tp: type) -> Tuple: """Get type arguments with all substitutions performed. Supports subscripted generics having __origin__""" - return tuple(arg if not hasattr(arg, "__origin__") else arg.__origin__ for arg in get_args(tp)) + return tuple( + remap_type(arg if not hasattr(arg, "__origin__") else arg.__origin__) + for arg in get_args(tp) + ) Filter = Union[rest.Filter, grpc.Filter] diff --git a/qdrant_client/local/local_collection.py b/qdrant_client/local/local_collection.py index a4be3e99..fbe4a15f 100644 --- a/qdrant_client/local/local_collection.py +++ b/qdrant_client/local/local_collection.py @@ -656,11 +656,11 @@ def _update_named_vectors(self, idx: int, vectors: Dict[str, List[float]]) -> No for vector_name, vector in vectors.items(): self.vectors[vector_name][idx] = np.array(vector) - def update_vectors(self, vectors: Sequence[types.PointVectors]) -> None: - for vector in vectors: - point_id = vector.id + def update_vectors(self, points: Sequence[types.PointVectors]) -> None: + for point in points: + point_id = point.id idx = self.ids[point_id] - vector_struct = vector.vector + vector_struct = point.vector if isinstance(vector_struct, list): fixed_vectors = {DEFAULT_VECTOR_NAME: vector_struct} else: diff --git a/qdrant_client/local/qdrant_local.py b/qdrant_client/local/qdrant_local.py index 7d4361a3..b748fbde 100644 --- a/qdrant_client/local/qdrant_local.py +++ b/qdrant_client/local/qdrant_local.py @@ -342,11 +342,11 @@ def upsert( def update_vectors( self, collection_name: str, - vectors: Sequence[types.PointVectors], + points: Sequence[types.PointVectors], **kwargs: Any, ) -> types.UpdateResult: collection = self._get_collection(collection_name) - collection.update_vectors(vectors) + collection.update_vectors(points) return self._default_update_result() def delete_vectors( diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index a4c8af7c..cb4a56f7 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -5,8 +5,8 @@ from qdrant_client.conversions import common_types as types from qdrant_client.http import ApiClient, SyncApis from qdrant_client.local.qdrant_local import QdrantLocal -from qdrant_client.qdrant_fastembed import QdrantFastembedMixin from qdrant_client.migrate import migrate +from qdrant_client.qdrant_fastembed import QdrantFastembedMixin from qdrant_client.qdrant_remote import QdrantRemote @@ -779,7 +779,7 @@ def upsert( def update_vectors( self, collection_name: str, - vectors: Sequence[types.PointVectors], + points: Sequence[types.PointVectors], wait: bool = True, ordering: Optional[types.WriteOrdering] = None, **kwargs: Any, @@ -788,7 +788,7 @@ def update_vectors( Args: collection_name: Name of the collection to update vectors in - vectors: List of (id, vector) pairs to update. Vector might be a list of numbers or a dict of named vectors. + points: List of (id, vector) pairs to update. Vector might be a list of numbers or a dict of named vectors. Example - `PointVectors(id=1, vector=[1, 2, 3])` - `PointVectors(id=2, vector={'vector_1': [1, 2, 3], 'vector_2': [4, 5, 6]})` @@ -807,7 +807,7 @@ def update_vectors( return self._client.update_vectors( collection_name=collection_name, - vectors=vectors, + points=points, wait=wait, ordering=ordering, ) diff --git a/qdrant_client/qdrant_remote.py b/qdrant_client/qdrant_remote.py index 4f5dcac4..23ac7257 100644 --- a/qdrant_client/qdrant_remote.py +++ b/qdrant_client/qdrant_remote.py @@ -1060,13 +1060,13 @@ def upsert( def update_vectors( self, collection_name: str, - vectors: Sequence[types.PointVectors], + points: Sequence[types.PointVectors], wait: bool = True, ordering: Optional[types.WriteOrdering] = None, **kwargs: Any, ) -> types.UpdateResult: if self._prefer_grpc: - vectors = [RestToGrpc.convert_point_vectors(vector) for vector in vectors] + points = [RestToGrpc.convert_point_vectors(point) for point in points] if isinstance(ordering, models.WriteOrdering): ordering = RestToGrpc.convert_write_ordering(ordering) @@ -1075,7 +1075,7 @@ def update_vectors( grpc.UpdatePointVectors( collection_name=collection_name, wait=wait, - vectors=vectors, + points=points, ordering=ordering, ) ).result @@ -1085,7 +1085,7 @@ def update_vectors( return self.openapi_client.points_api.update_vectors( collection_name=collection_name, wait=wait, - update_vectors=models.UpdateVectors(points=vectors), + update_vectors=models.UpdateVectors(points=points), ordering=ordering, ).result @@ -1111,7 +1111,7 @@ def delete_vectors( vectors=grpc.VectorsSelector( names=vectors, ), - points=points, + points_selector=points, ordering=ordering, ) ).result diff --git a/tests/congruence_tests/test_common.py b/tests/congruence_tests/test_common.py index f3e8f9d8..9700d2a9 100644 --- a/tests/congruence_tests/test_common.py +++ b/tests/congruence_tests/test_common.py @@ -80,7 +80,11 @@ def compare_collections( collection_1 = client_1.get_collection(collection_name) collection_2 = client_2.get_collection(collection_name) - assert all(getattr(collection_1, attr) == getattr(collection_2, attr) for attr in attrs) + for attr in attrs: + assert getattr(collection_1, attr) == getattr(collection_2, attr), ( + f"client_1.{attr} = {getattr(collection_1, attr)}, " + f"client_2.{attr} = {getattr(collection_2, attr)}" + ) # num_vectors * 2 to be sure that we have no excess points uploaded compare_client_results( diff --git a/tests/congruence_tests/test_delete_points.py b/tests/congruence_tests/test_delete_points.py index ed24a565..8cad0fa1 100644 --- a/tests/congruence_tests/test_delete_points.py +++ b/tests/congruence_tests/test_delete_points.py @@ -30,7 +30,7 @@ def test_delete_points(local_client, remote_client): local_client.delete(COLLECTION_NAME, found_ids) remote_client.delete(COLLECTION_NAME, found_ids) - compare_collections(local_client, remote_client, 100) + compare_collections(local_client, remote_client, 100, attrs=("points_count",)) compare_client_results( local_client, diff --git a/tests/congruence_tests/test_optional_vectors.py b/tests/congruence_tests/test_optional_vectors.py index affdd538..7adfb960 100644 --- a/tests/congruence_tests/test_optional_vectors.py +++ b/tests/congruence_tests/test_optional_vectors.py @@ -62,12 +62,12 @@ def test_simple_opt_vectors_search(): local_client.update_vectors( collection_name=COLLECTION_NAME, - vectors=update_vectors, + points=update_vectors, ) remote_client.update_vectors( collection_name=COLLECTION_NAME, - vectors=update_vectors, + points=update_vectors, ) compare_client_results( diff --git a/tests/test_qdrant_client.py b/tests/test_qdrant_client.py index 6136718f..322eaf13 100644 --- a/tests/test_qdrant_client.py +++ b/tests/test_qdrant_client.py @@ -10,7 +10,7 @@ from qdrant_client import QdrantClient from qdrant_client._pydantic_compat import to_dict -from qdrant_client.conversions.common_types import Record +from qdrant_client.conversions.common_types import PointVectors, Record from qdrant_client.conversions.conversion import grpc_to_payload, json_to_value from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.models import ( @@ -727,7 +727,10 @@ def test_qdrant_client_integration_update_collection(prefer_grpc): if version is not None and version >= "v1.4.0": assert collection_info.config.params.vectors["text"].hnsw_config.m == 32 assert collection_info.config.params.vectors["text"].hnsw_config.ef_construct == 123 - assert collection_info.config.params.vectors["text"].quantization_config.product.compression == CompressionRatio.X32 + assert ( + collection_info.config.params.vectors["text"].quantization_config.product.compression + == CompressionRatio.X32 + ) assert collection_info.config.params.vectors["text"].quantization_config.product.always_ram assert collection_info.config.params.vectors["text"].on_disk assert collection_info.config.hnsw_config.ef_construct == 123 @@ -837,6 +840,65 @@ def test_quantization_config(prefer_grpc): ) +@pytest.mark.parametrize("prefer_grpc", [False, True]) +def test_vector_update(prefer_grpc): + client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT) + + client.recreate_collection( + collection_name=COLLECTION_NAME, + vectors_config=VectorParams(size=DIM, distance=Distance.DOT), + timeout=TIMEOUT, + ) + + uuid1 = str(uuid.uuid4()) + uuid2 = str(uuid.uuid4()) + uuid3 = str(uuid.uuid4()) + uuid4 = str(uuid.uuid4()) + + client.upsert( + collection_name=COLLECTION_NAME, + points=[ + PointStruct(id=uuid1, payload={"a": 1}, vector=np.random.rand(DIM).tolist()), + PointStruct(id=uuid2, payload={"a": 2}, vector=np.random.rand(DIM).tolist()), + PointStruct(id=uuid3, payload={"b": 1}, vector=np.random.rand(DIM).tolist()), + PointStruct(id=uuid4, payload={"b": 2}, vector=np.random.rand(DIM).tolist()), + ], + wait=True, + ) + + client.update_vectors( + collection_name=COLLECTION_NAME, + points=[ + PointVectors( + id=uuid2, + vector=[1.0] * DIM, + ) + ], + ) + + result = client.retrieve( + collection_name=COLLECTION_NAME, + ids=[uuid2], + with_vectors=True, + )[0] + + assert result.vector == [1] * DIM + + client.delete_vectors( + collection_name=COLLECTION_NAME, + vectors=[""], + points=Filter(must=[FieldCondition(key="b", range=Range(gte=1))]), + ) + + result = client.retrieve( + collection_name=COLLECTION_NAME, + ids=[uuid4], + with_vectors=True, + )[0] + + assert result.vector == {} + + @pytest.mark.parametrize("prefer_grpc", [False, True]) def test_conditional_payload_update(prefer_grpc): client = QdrantClient(prefer_grpc=prefer_grpc, timeout=TIMEOUT) @@ -1157,4 +1219,4 @@ def test_client_close(): test_points_crud() test_has_id_condition() test_insert_float() - test_legacy_imports() \ No newline at end of file + test_legacy_imports()