From 81d28732427e2c473e0826058277ccaac5f01566 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Thu, 3 Aug 2023 20:24:23 +0200 Subject: [PATCH 01/11] new: migration tool v1 --- migrate.py | 122 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 migrate.py diff --git a/migrate.py b/migrate.py new file mode 100644 index 00000000..70aedd7b --- /dev/null +++ b/migrate.py @@ -0,0 +1,122 @@ +from qdrant_client import QdrantClient +from qdrant_client.http import models + + +def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size: int = 100) -> None: + source_collections = source_client.get_collections().collections + dest_collections = dest_client.get_collections().collections + + source_collection_names = {collection.name for collection in source_collections} + dest_collection_names = {collection.name for collection in dest_collections} + + missing_collections = source_collection_names - dest_collection_names + assert ( + not missing_collections + ), f"Destination client should have all collections from source client. Missing collections: {missing_collections}" + + compare_collections(list(source_collection_names), source_client, dest_client) + + for collection_name in source_collection_names: + migrate_collection(collection_name, source_client, dest_client, batch_size) + + +def compare_collections( + source_collection_names: list[str], + source_client: QdrantClient, + dest_client: QdrantClient, +) -> bool: + for collection_name in source_collection_names: + source_collection = source_client.get_collection(collection_name) + source_vector_params = source_collection.config.params.vectors + dest_collection = dest_client.get_collection(collection_name) + dest_vector_params = dest_collection.config.params.vectors + + if isinstance(source_vector_params, models.VectorParams): + assert isinstance(dest_vector_params, models.VectorParams), "Mismatched vector params" + assert ( + source_vector_params.size == dest_vector_params.size + ), "Vector size should be equal" + assert ( + source_vector_params.distance == dest_vector_params.distance + ), "Distance should be equal" + + elif isinstance(source_vector_params, dict): + for key, source_vector_param in source_vector_params.items(): + dest_vector_param = dest_vector_params[key] + assert ( + source_vector_param.size == dest_vector_param.size + ), f"Vector size is not equal for {key} in {collection_name}" + assert ( + source_vector_param.distance == dest_vector_param.distance + ), f"Distance is not the same for {key} in {collection_name}" + + return True + + +def migrate_collection( + collection_name: str, + source_client: QdrantClient, + dest_client: QdrantClient, + batch_size: int = 100, +) -> None: + records, next_offset = source_client.scroll( + collection_name, limit=batch_size, with_vectors=True + ) + dest_client.upload_records(collection_name, records) + while next_offset: + records, next_offset = source_client.scroll( + collection_name, offset=next_offset, limit=batch_size, with_vectors=True + ) + dest_client.upload_records(collection_name, records) + + source_client_vectors_count = source_client.get_collection(collection_name).vectors_count + dest_client_vectors_count = dest_client.get_collection(collection_name).vectors_count + assert ( + source_client_vectors_count == dest_client_vectors_count + ), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}" + + +if __name__ == "__main__": + import numpy as np + + VECTOR_NUMBER = 1000 + + local_client = QdrantClient(":memory:") + remote_client = QdrantClient() + + single_vector_collection_kwargs = { + "collection_name": "single_vector_collection", + "vectors_config": models.VectorParams(size=10, distance=models.Distance.COSINE), + } + multiple_vectors_collection_kwargs = { + "collection_name": "multiple_vectors_collection", + "vectors_config": { + "text": models.VectorParams(size=10, distance=models.Distance.EUCLID), + "image": models.VectorParams(size=11, distance=models.Distance.COSINE), + }, + } + local_client.recreate_collection(**single_vector_collection_kwargs) + local_client.recreate_collection(**multiple_vectors_collection_kwargs) + remote_client.recreate_collection(**single_vector_collection_kwargs) + remote_client.recreate_collection(**multiple_vectors_collection_kwargs) + + local_client.upload_collection( + single_vector_collection_kwargs["collection_name"], + vectors=np.random.randn( + VECTOR_NUMBER, single_vector_collection_kwargs["vectors_config"].size + ), + ) + local_client.upload_collection( + multiple_vectors_collection_kwargs["collection_name"], + vectors={ + "text": np.random.randn( + VECTOR_NUMBER, + multiple_vectors_collection_kwargs["vectors_config"]["text"].size, + ), + "image": np.random.randn( + VECTOR_NUMBER, + multiple_vectors_collection_kwargs["vectors_config"]["image"].size, + ), + }, + ) + migrate(local_client, remote_client) From 7c0458d0346ad178f9e5eb94ba6f593927592bc2 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Sat, 5 Aug 2023 22:32:16 +0200 Subject: [PATCH 02/11] new: add docstrings --- migrate.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/migrate.py b/migrate.py index 70aedd7b..28d9e156 100644 --- a/migrate.py +++ b/migrate.py @@ -3,6 +3,13 @@ def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size: int = 100) -> None: + """Migrate all collections from source client to destination client + + Args: + source_client (QdrantClient): Source client + dest_client (QdrantClient): Destination client + batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. + """ source_collections = source_client.get_collections().collections dest_collections = dest_client.get_collections().collections @@ -14,17 +21,30 @@ def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size: not missing_collections ), f"Destination client should have all collections from source client. Missing collections: {missing_collections}" - compare_collections(list(source_collection_names), source_client, dest_client) + _compare_collections(list(source_collection_names), source_client, dest_client) + print(f"Number of collections to migrate: {len(source_collection_names)}", end="\n\n") for collection_name in source_collection_names: - migrate_collection(collection_name, source_client, dest_client, batch_size) + print(f"Start migrating collection `{collection_name}`") + _migrate_collection(collection_name, source_client, dest_client, batch_size) + print(f"Finish migrating collection `{collection_name}`", end="\n\n") -def compare_collections( +def _compare_collections( source_collection_names: list[str], source_client: QdrantClient, dest_client: QdrantClient, ) -> bool: + """Compare collections from source client and destination client + + Args: + source_collection_names (list[str]): List of collection names + source_client (QdrantClient): Source client + dest_client (QdrantClient): Destination client + + Returns: + bool: True if collections have the same vector and distance params + """ for collection_name in source_collection_names: source_collection = source_client.get_collection(collection_name) source_vector_params = source_collection.config.params.vectors @@ -53,12 +73,20 @@ def compare_collections( return True -def migrate_collection( +def _migrate_collection( collection_name: str, source_client: QdrantClient, dest_client: QdrantClient, batch_size: int = 100, ) -> None: + """Migrate collection from source client to destination client + + Args: + collection_name (str): Collection name + source_client (QdrantClient): Source client + dest_client (QdrantClient): Destination client + batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. + """ records, next_offset = source_client.scroll( collection_name, limit=batch_size, with_vectors=True ) From 50bc5cb64eadcf56b9b4321259e444268ea66d1a Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Sat, 5 Aug 2023 22:54:52 +0200 Subject: [PATCH 03/11] fix: fix collection comparison, add tests --- qdrant_client/migrate/__init__.py | 1 + .../migrate/migrate.py | 52 +-------- qdrant_client/migrate/tests/test_migrate.py | 108 ++++++++++++++++++ 3 files changed, 114 insertions(+), 47 deletions(-) create mode 100644 qdrant_client/migrate/__init__.py rename migrate.py => qdrant_client/migrate/migrate.py (72%) create mode 100644 qdrant_client/migrate/tests/test_migrate.py diff --git a/qdrant_client/migrate/__init__.py b/qdrant_client/migrate/__init__.py new file mode 100644 index 00000000..1c629e26 --- /dev/null +++ b/qdrant_client/migrate/__init__.py @@ -0,0 +1 @@ +from .migrate import migrate diff --git a/migrate.py b/qdrant_client/migrate/migrate.py similarity index 72% rename from migrate.py rename to qdrant_client/migrate/migrate.py index 28d9e156..f5c4e698 100644 --- a/migrate.py +++ b/qdrant_client/migrate/migrate.py @@ -61,6 +61,10 @@ def _compare_collections( ), "Distance should be equal" elif isinstance(source_vector_params, dict): + assert len(source_vector_params) == len( + dest_vector_params + ), "Mismatched vector params: number of named vectors is not equal" + for key, source_vector_param in source_vector_params.items(): dest_vector_param = dest_vector_params[key] assert ( @@ -70,7 +74,7 @@ def _compare_collections( source_vector_param.distance == dest_vector_param.distance ), f"Distance is not the same for {key} in {collection_name}" - return True + return True def _migrate_collection( @@ -102,49 +106,3 @@ def _migrate_collection( assert ( source_client_vectors_count == dest_client_vectors_count ), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}" - - -if __name__ == "__main__": - import numpy as np - - VECTOR_NUMBER = 1000 - - local_client = QdrantClient(":memory:") - remote_client = QdrantClient() - - single_vector_collection_kwargs = { - "collection_name": "single_vector_collection", - "vectors_config": models.VectorParams(size=10, distance=models.Distance.COSINE), - } - multiple_vectors_collection_kwargs = { - "collection_name": "multiple_vectors_collection", - "vectors_config": { - "text": models.VectorParams(size=10, distance=models.Distance.EUCLID), - "image": models.VectorParams(size=11, distance=models.Distance.COSINE), - }, - } - local_client.recreate_collection(**single_vector_collection_kwargs) - local_client.recreate_collection(**multiple_vectors_collection_kwargs) - remote_client.recreate_collection(**single_vector_collection_kwargs) - remote_client.recreate_collection(**multiple_vectors_collection_kwargs) - - local_client.upload_collection( - single_vector_collection_kwargs["collection_name"], - vectors=np.random.randn( - VECTOR_NUMBER, single_vector_collection_kwargs["vectors_config"].size - ), - ) - local_client.upload_collection( - multiple_vectors_collection_kwargs["collection_name"], - vectors={ - "text": np.random.randn( - VECTOR_NUMBER, - multiple_vectors_collection_kwargs["vectors_config"]["text"].size, - ), - "image": np.random.randn( - VECTOR_NUMBER, - multiple_vectors_collection_kwargs["vectors_config"]["image"].size, - ), - }, - ) - migrate(local_client, remote_client) diff --git a/qdrant_client/migrate/tests/test_migrate.py b/qdrant_client/migrate/tests/test_migrate.py new file mode 100644 index 00000000..ad53db0c --- /dev/null +++ b/qdrant_client/migrate/tests/test_migrate.py @@ -0,0 +1,108 @@ +import numpy as np +import pytest + +from qdrant_client import QdrantClient +from qdrant_client.http import models +from qdrant_client.migrate import migrate + +VECTOR_NUMBER = 1000 + + +@pytest.fixture +def source_client() -> QdrantClient: + client = QdrantClient(":memory:") + yield client + client.close() + + +@pytest.fixture +def dest_client() -> QdrantClient: + client = QdrantClient() + yield client + client.close() + + +def test_single_vector_collection(source_client: QdrantClient, dest_client: QdrantClient) -> None: + single_vector_collection_kwargs = { + "collection_name": "single_vector_collection", + "vectors_config": models.VectorParams(size=10, distance=models.Distance.COSINE), + } + source_client.recreate_collection(**single_vector_collection_kwargs) + dest_client.recreate_collection(**single_vector_collection_kwargs) + + source_client.upload_collection( + single_vector_collection_kwargs["collection_name"], + vectors=np.random.randn( + VECTOR_NUMBER, single_vector_collection_kwargs["vectors_config"].size + ), + ) + + migrate(source_client, dest_client) + + +def test_multiple_vectors_collection( + source_client: QdrantClient, dest_client: QdrantClient +) -> None: + multiple_vectors_collection_kwargs = { + "collection_name": "multiple_vectors_collection", + "vectors_config": { + "text": models.VectorParams(size=10, distance=models.Distance.EUCLID), + "image": models.VectorParams(size=11, distance=models.Distance.COSINE), + }, + } + source_client.recreate_collection(**multiple_vectors_collection_kwargs) + dest_client.recreate_collection(**multiple_vectors_collection_kwargs) + source_client.upload_collection( + multiple_vectors_collection_kwargs["collection_name"], + vectors={ + "text": np.random.randn( + VECTOR_NUMBER, + multiple_vectors_collection_kwargs["vectors_config"]["text"].size, + ), + "image": np.random.randn( + VECTOR_NUMBER, + multiple_vectors_collection_kwargs["vectors_config"]["image"].size, + ), + }, + ) + + +def test_different_distances(source_client: QdrantClient, dest_client: QdrantClient) -> None: + collection_name = "single_vector_collection" + cosine_params = models.VectorParams(size=10, distance=models.Distance.COSINE) + euclid_params = models.VectorParams(size=10, distance=models.Distance.EUCLID) + + source_client.recreate_collection(collection_name, vectors_config=cosine_params) + dest_client.recreate_collection(collection_name, vectors_config=euclid_params) + + with pytest.raises(AssertionError): + migrate(source_client, dest_client) + + +def test_different_vector_sizes(source_client: QdrantClient, dest_client: QdrantClient) -> None: + collection_name = "single_vector_collection" + small_vector_params = models.VectorParams(size=10, distance=models.Distance.COSINE) + big_vector_params = models.VectorParams(size=100, distance=models.Distance.COSINE) + + source_client.recreate_collection(collection_name, vectors_config=small_vector_params) + dest_client.recreate_collection(collection_name, vectors_config=big_vector_params) + + with pytest.raises(AssertionError): + migrate(source_client, dest_client) + + +def test_single_vs_multiple_vectors( + source_client: QdrantClient, dest_client: QdrantClient +) -> None: + collection_name = "test_collection" + single_vector_params = {"text": models.VectorParams(size=10, distance=models.Distance.COSINE)} + multiple_vectors_params = { + "text": models.VectorParams(size=10, distance=models.Distance.COSINE), + "image": models.VectorParams(size=11, distance=models.Distance.COSINE), + } + + source_client.recreate_collection(collection_name, vectors_config=single_vector_params) + dest_client.recreate_collection(collection_name, vectors_config=multiple_vectors_params) + + with pytest.raises(AssertionError): + migrate(source_client, dest_client) From 37b738277701876263ff36119bb98d8a3d70d53e Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Sat, 5 Aug 2023 22:56:50 +0200 Subject: [PATCH 04/11] fix: fix mypy complaints --- qdrant_client/migrate/migrate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/qdrant_client/migrate/migrate.py b/qdrant_client/migrate/migrate.py index f5c4e698..5df6a24d 100644 --- a/qdrant_client/migrate/migrate.py +++ b/qdrant_client/migrate/migrate.py @@ -1,3 +1,5 @@ +from typing import List + from qdrant_client import QdrantClient from qdrant_client.http import models @@ -31,7 +33,7 @@ def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size: def _compare_collections( - source_collection_names: list[str], + source_collection_names: List[str], source_client: QdrantClient, dest_client: QdrantClient, ) -> bool: From 366557c5380697205ebb22ee1fad9755afe3e8dc Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Sat, 5 Aug 2023 23:05:16 +0200 Subject: [PATCH 05/11] tests: test multiple collections scenario --- qdrant_client/migrate/tests/test_migrate.py | 22 +++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/qdrant_client/migrate/tests/test_migrate.py b/qdrant_client/migrate/tests/test_migrate.py index ad53db0c..1d2e4650 100644 --- a/qdrant_client/migrate/tests/test_migrate.py +++ b/qdrant_client/migrate/tests/test_migrate.py @@ -67,6 +67,28 @@ def test_multiple_vectors_collection( ) +def test_multiple_collections(source_client: QdrantClient, dest_client: QdrantClient) -> None: + collection_names = ["collection_1", "collection_2", "collection_3"] + vector_params = models.VectorParams(size=10, distance=models.Distance.COSINE) + for collection_name in collection_names: + source_client.recreate_collection(collection_name, vectors_config=vector_params) + dest_client.recreate_collection(collection_name, vectors_config=vector_params) + source_client.upload_collection( + collection_name, + vectors=np.random.randn(VECTOR_NUMBER, vector_params.size), + ) + + migrate(source_client, dest_client) + + for collection_name in collection_names: + source_vector_number = source_client.get_collection(collection_name).vectors_count + dest_vector_number = dest_client.get_collection(collection_name).vectors_count + + assert ( + source_vector_number == dest_vector_number == VECTOR_NUMBER + ), f"Migration failed. Source vectors count {source_vector_number}, dest vectors count {dest_vector_number}, expected {VECTOR_NUMBER}" + + def test_different_distances(source_client: QdrantClient, dest_client: QdrantClient) -> None: collection_name = "single_vector_collection" cosine_params = models.VectorParams(size=10, distance=models.Distance.COSINE) From 22a61aab196304f2f7c91f1007cabf5b8ed0644f Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Thu, 10 Aug 2023 21:11:30 +0200 Subject: [PATCH 06/11] new: add migrate method to client's interface --- qdrant_client/client_base.py | 9 +++++++ qdrant_client/migrate/migrate.py | 24 +++++++++---------- qdrant_client/qdrant_client.py | 4 ++++ .../migrate/tests => tests}/test_migrate.py | 12 +++++----- 4 files changed, 31 insertions(+), 18 deletions(-) rename {qdrant_client/migrate/tests => tests}/test_migrate.py (95%) diff --git a/qdrant_client/client_base.py b/qdrant_client/client_base.py index 2d3e89d5..97fdc3c0 100644 --- a/qdrant_client/client_base.py +++ b/qdrant_client/client_base.py @@ -1,5 +1,11 @@ +import sys from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union +if sys.version_info <= (3, 11): + from typing_extensions import Self +else: + from typing import Self + from qdrant_client.conversions import common_types as types from qdrant_client.http import models @@ -326,3 +332,6 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption: def close(self, **kwargs: Any) -> None: pass + + def migrate(self, dest_client: Self, batch_size: int = 100) -> None: + raise NotImplementedError() diff --git a/qdrant_client/migrate/migrate.py b/qdrant_client/migrate/migrate.py index 5df6a24d..5231d255 100644 --- a/qdrant_client/migrate/migrate.py +++ b/qdrant_client/migrate/migrate.py @@ -1,15 +1,15 @@ from typing import List -from qdrant_client import QdrantClient +from qdrant_client.client_base import QdrantBase from qdrant_client.http import models -def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size: int = 100) -> None: +def migrate(source_client: QdrantBase, dest_client: QdrantBase, batch_size: int = 100) -> None: """Migrate all collections from source client to destination client Args: - source_client (QdrantClient): Source client - dest_client (QdrantClient): Destination client + source_client (QdrantBase): Source client + dest_client (QdrantBase): Destination client batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. """ source_collections = source_client.get_collections().collections @@ -34,15 +34,15 @@ def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size: def _compare_collections( source_collection_names: List[str], - source_client: QdrantClient, - dest_client: QdrantClient, + source_client: QdrantBase, + dest_client: QdrantBase, ) -> bool: """Compare collections from source client and destination client Args: source_collection_names (list[str]): List of collection names - source_client (QdrantClient): Source client - dest_client (QdrantClient): Destination client + source_client (QdrantBase): Source client + dest_client (QdrantBase): Destination client Returns: bool: True if collections have the same vector and distance params @@ -81,16 +81,16 @@ def _compare_collections( def _migrate_collection( collection_name: str, - source_client: QdrantClient, - dest_client: QdrantClient, + source_client: QdrantBase, + dest_client: QdrantBase, batch_size: int = 100, ) -> None: """Migrate collection from source client to destination client Args: collection_name (str): Collection name - source_client (QdrantClient): Source client - dest_client (QdrantClient): Destination client + source_client (QdrantBase): Source client + dest_client (QdrantBase): Destination client batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. """ records, next_offset = source_client.scroll( diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index c07c5a73..afabf002 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -6,6 +6,7 @@ from qdrant_client.http import ApiClient, SyncApis from qdrant_client.local.qdrant_local import QdrantLocal from qdrant_client.qdrant_fastembed import QdrantFastembedMixin +from qdrant_client.migrate import migrate from qdrant_client.qdrant_remote import QdrantRemote @@ -1732,3 +1733,6 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption: assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}" return self._client.get_locks(**kwargs) + + def migrate(self, dest_client: QdrantBase, batch_size: int = 100) -> None: + migrate(self, dest_client, batch_size=batch_size) diff --git a/qdrant_client/migrate/tests/test_migrate.py b/tests/test_migrate.py similarity index 95% rename from qdrant_client/migrate/tests/test_migrate.py rename to tests/test_migrate.py index 1d2e4650..57b99344 100644 --- a/qdrant_client/migrate/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -3,7 +3,6 @@ from qdrant_client import QdrantClient from qdrant_client.http import models -from qdrant_client.migrate import migrate VECTOR_NUMBER = 1000 @@ -37,7 +36,7 @@ def test_single_vector_collection(source_client: QdrantClient, dest_client: Qdra ), ) - migrate(source_client, dest_client) + source_client.migrate(dest_client) def test_multiple_vectors_collection( @@ -65,6 +64,7 @@ def test_multiple_vectors_collection( ), }, ) + source_client.migrate(dest_client) def test_multiple_collections(source_client: QdrantClient, dest_client: QdrantClient) -> None: @@ -78,7 +78,7 @@ def test_multiple_collections(source_client: QdrantClient, dest_client: QdrantCl vectors=np.random.randn(VECTOR_NUMBER, vector_params.size), ) - migrate(source_client, dest_client) + source_client.migrate(dest_client) for collection_name in collection_names: source_vector_number = source_client.get_collection(collection_name).vectors_count @@ -98,7 +98,7 @@ def test_different_distances(source_client: QdrantClient, dest_client: QdrantCli dest_client.recreate_collection(collection_name, vectors_config=euclid_params) with pytest.raises(AssertionError): - migrate(source_client, dest_client) + source_client.migrate(dest_client) def test_different_vector_sizes(source_client: QdrantClient, dest_client: QdrantClient) -> None: @@ -110,7 +110,7 @@ def test_different_vector_sizes(source_client: QdrantClient, dest_client: Qdrant dest_client.recreate_collection(collection_name, vectors_config=big_vector_params) with pytest.raises(AssertionError): - migrate(source_client, dest_client) + source_client.migrate(dest_client) def test_single_vs_multiple_vectors( @@ -127,4 +127,4 @@ def test_single_vs_multiple_vectors( dest_client.recreate_collection(collection_name, vectors_config=multiple_vectors_params) with pytest.raises(AssertionError): - migrate(source_client, dest_client) + source_client.migrate(dest_client) From 948d4f5e808089334c7f2d51d093b077b1bade23 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Fri, 18 Aug 2023 15:33:39 +0200 Subject: [PATCH 07/11] new: update migration tool --- qdrant_client/client_base.py | 16 +-- qdrant_client/migrate/migrate.py | 224 +++++++++++++++++++++++-------- qdrant_client/qdrant_client.py | 16 ++- 3 files changed, 187 insertions(+), 69 deletions(-) diff --git a/qdrant_client/client_base.py b/qdrant_client/client_base.py index 97fdc3c0..905d155c 100644 --- a/qdrant_client/client_base.py +++ b/qdrant_client/client_base.py @@ -1,13 +1,7 @@ -import sys from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union -if sys.version_info <= (3, 11): - from typing_extensions import Self -else: - from typing import Self - +from qdrant_client import models from qdrant_client.conversions import common_types as types -from qdrant_client.http import models class QdrantBase: @@ -333,5 +327,11 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption: def close(self, **kwargs: Any) -> None: pass - def migrate(self, dest_client: Self, batch_size: int = 100) -> None: + def migrate( + self, + dest_client: "QdrantBase", + collection_names: Optional[List[str]] = None, + batch_size: int = 100, + raise_on_collision: bool = True, + ) -> None: raise NotImplementedError() diff --git a/qdrant_client/migrate/migrate.py b/qdrant_client/migrate/migrate.py index 5231d255..0fdd5f5a 100644 --- a/qdrant_client/migrate/migrate.py +++ b/qdrant_client/migrate/migrate.py @@ -1,88 +1,194 @@ -from typing import List +from enum import Enum +from typing import Dict, List, Optional +from qdrant_client import models +from qdrant_client._pydantic_compat import to_dict from qdrant_client.client_base import QdrantBase -from qdrant_client.http import models -def migrate(source_client: QdrantBase, dest_client: QdrantBase, batch_size: int = 100) -> None: - """Migrate all collections from source client to destination client +class MigrationCollisionAction(str, Enum): + """Action on collection configuration collision""" + + RAISE = "Raise" + RECREATE = "Recreate" + SKIP = "Skip" + + +def migrate( + source_client: QdrantBase, + dest_client: QdrantBase, + collection_names: Optional[List[str]], + on_collision_action: MigrationCollisionAction = MigrationCollisionAction.RAISE, + batch_size: int = 100, +): + """ + Migrate collections from source client to destination client Args: source_client (QdrantBase): Source client dest_client (QdrantBase): Destination client + collection_names (list[str], optional): List of collection names to migrate. + If None - migrate all source client collections. Defaults to None. + on_collision_action (MigrationCollisionAction): Action on collection configuration collision. Defaults to + MigrationCollisionAction.RAISE. batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. """ - source_collections = source_client.get_collections().collections - dest_collections = dest_client.get_collections().collections + collection_names = select_source_collections(source_client, collection_names) + existing_dest_collections = get_existing_dest_collections(dest_client, collection_names) + absent_dest_collections = set(collection_names) - set(existing_dest_collections) + + collisions = find_collisions(source_client, dest_client, existing_dest_collections) + if collisions and on_collision_action == MigrationCollisionAction.RAISE: + raise ValueError(f"Collision detected: {collisions}") + elif collisions and on_collision_action == MigrationCollisionAction.SKIP: + collisions = [] + + for collection_name in absent_dest_collections: + recreate_collection(source_client, dest_client, collection_name) + migrate_collection(source_client, dest_client, collection_name, batch_size) + + for collection_name in collisions: + recreate_collection(source_client, dest_client, collection_name) + migrate_collection(source_client, dest_client, collection_name, batch_size) + + +def recreate_collection( + source_client: QdrantBase, + dest_client: QdrantBase, + collection_name: str, +) -> None: + src_collection_info = source_client.get_collection(collection_name) + src_config = src_collection_info.config + src_payload_schema = src_collection_info.payload_schema + + dest_client.recreate_collection( + collection_name, + vectors_config=src_config.params.vectors, + shard_number=src_config.params.shard_number, + replication_factor=src_config.params.replication_factor, + write_consistency_factor=src_config.params.write_consistency_factor, + on_disk_payload=src_config.params.on_disk_payload, + hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)), + optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)), + wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)), + quantization_config=src_config.quantization_config, + ) + + recreate_payload_schema(dest_client, collection_name, src_payload_schema) + + +def recreate_payload_schema( + dest_client: QdrantBase, + collection_name: str, + payload_schema: Dict[str, models.PayloadIndexInfo], +) -> None: + for field_name, field_info in payload_schema.items(): + dest_client.create_payload_index( + collection_name, + field_name=field_name, + field_schema=field_info.type if field_info.params is None else field_info.params, + ) + +def select_source_collections(source_client: QdrantBase, collection_names: List[str]) -> List[str]: + source_collections = source_client.get_collections().collections source_collection_names = {collection.name for collection in source_collections} + + if collection_names: + assert all( + collection_name in source_collection_names for collection_name in collection_names + ), f"Source client does not have collections: {set(collection_names) - source_collection_names}" + else: + collection_names = source_collection_names + return collection_names + + +def get_existing_dest_collections( + dest_client: QdrantBase, collection_names: List[str] +) -> List[str]: + dest_collections = dest_client.get_collections().collections dest_collection_names = {collection.name for collection in dest_collections} + existing_dest_collections = dest_collection_names & set(collection_names) + return list(existing_dest_collections) - missing_collections = source_collection_names - dest_collection_names - assert ( - not missing_collections - ), f"Destination client should have all collections from source client. Missing collections: {missing_collections}" - _compare_collections(list(source_collection_names), source_client, dest_client) +def find_collisions( + source_client: QdrantBase, dest_client: QdrantBase, collection_names: List[str] +) -> List[str]: + collision_collection_names = [] + for collection_name in collection_names: + src_collection_info = source_client.get_collection(collection_name) + dest_collection_info = dest_client.get_collection(collection_name) + collision = check_collision(src_collection_info, dest_collection_info) + if collision: + collision_collection_names.append(collection_name) - print(f"Number of collections to migrate: {len(source_collection_names)}", end="\n\n") - for collection_name in source_collection_names: - print(f"Start migrating collection `{collection_name}`") - _migrate_collection(collection_name, source_client, dest_client, batch_size) - print(f"Finish migrating collection `{collection_name}`", end="\n\n") + return collection_names -def _compare_collections( - source_collection_names: List[str], - source_client: QdrantBase, - dest_client: QdrantBase, +def check_collision( + src_collection_info: models.CollectionInfo, + dest_collection_info: models.CollectionInfo, ) -> bool: - """Compare collections from source client and destination client + """Check if collection configurations collide Args: - source_collection_names (list[str]): List of collection names - source_client (QdrantBase): Source client - dest_client (QdrantBase): Destination client + src_collection_info (models.CollectionInfo): Source collection info + dest_collection_info (models.CollectionInfo): Destination collection info Returns: - bool: True if collections have the same vector and distance params + bool: True if collection configurations collide, False otherwise """ - for collection_name in source_collection_names: - source_collection = source_client.get_collection(collection_name) - source_vector_params = source_collection.config.params.vectors - dest_collection = dest_client.get_collection(collection_name) - dest_vector_params = dest_collection.config.params.vectors - - if isinstance(source_vector_params, models.VectorParams): - assert isinstance(dest_vector_params, models.VectorParams), "Mismatched vector params" - assert ( - source_vector_params.size == dest_vector_params.size - ), "Vector size should be equal" - assert ( - source_vector_params.distance == dest_vector_params.distance - ), "Distance should be equal" - - elif isinstance(source_vector_params, dict): - assert len(source_vector_params) == len( - dest_vector_params - ), "Mismatched vector params: number of named vectors is not equal" - - for key, source_vector_param in source_vector_params.items(): - dest_vector_param = dest_vector_params[key] - assert ( - source_vector_param.size == dest_vector_param.size - ), f"Vector size is not equal for {key} in {collection_name}" - assert ( - source_vector_param.distance == dest_vector_param.distance - ), f"Distance is not the same for {key} in {collection_name}" - - return True - - -def _migrate_collection( - collection_name: str, + + def check_vector_params_collision( + src_vector_params: models.VectorParams, dest_vector_params: models.VectorParams + ) -> bool: + """Check if vector params collide + + Args: + src_vector_params (models.VectorParams): Source vector params + dest_vector_params (models.VectorParams): Destination vector params + + Returns: + bool: True if vector params collide, False otherwise + """ + if src_vector_params.size != dest_vector_params.size: + return True + + if src_vector_params.distance != dest_vector_params.distance: + return True + + return False + + source_vector_params = src_collection_info.config.params.vectors + destination_vector_params = dest_collection_info.config.params.vectors + + if isinstance(source_vector_params, models.VectorParams): + if not isinstance(destination_vector_params, models.VectorParams): + return True + + return check_vector_params_collision(source_vector_params, destination_vector_params) + + if isinstance(source_vector_params, dict): + if not isinstance(destination_vector_params, dict): + return True + + for key in source_vector_params: + if key not in destination_vector_params: + return True + + if check_vector_params_collision( + source_vector_params[key], destination_vector_params[key] + ): + return True + + return False + + +def migrate_collection( source_client: QdrantBase, dest_client: QdrantBase, + collection_name: str, batch_size: int = 100, ) -> None: """Migrate collection from source client to destination client diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index afabf002..6e671a8e 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -1734,5 +1734,17 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption: return self._client.get_locks(**kwargs) - def migrate(self, dest_client: QdrantBase, batch_size: int = 100) -> None: - migrate(self, dest_client, batch_size=batch_size) + def migrate( + self, + dest_client: QdrantBase, + collection_names: Optional[List[str]] = None, + batch_size: int = 100, + raise_on_collision: bool = True, + ) -> None: + migrate( + self, + dest_client, + collection_names=collection_names, + batch_size=batch_size, + raise_on_collision=raise_on_collision, + ) From 524082e422efbebf904c9b5cf78e49e8e4059657 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Fri, 18 Aug 2023 15:40:14 +0200 Subject: [PATCH 08/11] new: simplify collisions --- qdrant_client/migrate/migrate.py | 126 ++++++------------------------- 1 file changed, 25 insertions(+), 101 deletions(-) diff --git a/qdrant_client/migrate/migrate.py b/qdrant_client/migrate/migrate.py index 0fdd5f5a..b875a92f 100644 --- a/qdrant_client/migrate/migrate.py +++ b/qdrant_client/migrate/migrate.py @@ -7,7 +7,7 @@ class MigrationCollisionAction(str, Enum): - """Action on collection configuration collision""" + """Action on collection collision""" RAISE = "Raise" RECREATE = "Recreate" @@ -29,17 +29,16 @@ def migrate( dest_client (QdrantBase): Destination client collection_names (list[str], optional): List of collection names to migrate. If None - migrate all source client collections. Defaults to None. - on_collision_action (MigrationCollisionAction): Action on collection configuration collision. Defaults to + on_collision_action (MigrationCollisionAction): Action on collection collision. Defaults to MigrationCollisionAction.RAISE. batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. """ collection_names = select_source_collections(source_client, collection_names) - existing_dest_collections = get_existing_dest_collections(dest_client, collection_names) - absent_dest_collections = set(collection_names) - set(existing_dest_collections) + collisions = find_collisions(dest_client, collection_names) + absent_dest_collections = set(collection_names) - set(collisions) - collisions = find_collisions(source_client, dest_client, existing_dest_collections) if collisions and on_collision_action == MigrationCollisionAction.RAISE: - raise ValueError(f"Collision detected: {collisions}") + raise ValueError(f"Collections already exist in dest_client: {collisions}") elif collisions and on_collision_action == MigrationCollisionAction.SKIP: collisions = [] @@ -52,6 +51,26 @@ def migrate( migrate_collection(source_client, dest_client, collection_name, batch_size) +def select_source_collections(source_client: QdrantBase, collection_names: List[str]) -> List[str]: + source_collections = source_client.get_collections().collections + source_collection_names = {collection.name for collection in source_collections} + + if collection_names: + assert all( + collection_name in source_collection_names for collection_name in collection_names + ), f"Source client does not have collections: {set(collection_names) - source_collection_names}" + else: + collection_names = source_collection_names + return collection_names + + +def find_collisions(dest_client: QdrantBase, collection_names: List[str]) -> List[str]: + dest_collections = dest_client.get_collections().collections + dest_collection_names = {collection.name for collection in dest_collections} + existing_dest_collections = dest_collection_names & set(collection_names) + return list(existing_dest_collections) + + def recreate_collection( source_client: QdrantBase, dest_client: QdrantBase, @@ -90,101 +109,6 @@ def recreate_payload_schema( ) -def select_source_collections(source_client: QdrantBase, collection_names: List[str]) -> List[str]: - source_collections = source_client.get_collections().collections - source_collection_names = {collection.name for collection in source_collections} - - if collection_names: - assert all( - collection_name in source_collection_names for collection_name in collection_names - ), f"Source client does not have collections: {set(collection_names) - source_collection_names}" - else: - collection_names = source_collection_names - return collection_names - - -def get_existing_dest_collections( - dest_client: QdrantBase, collection_names: List[str] -) -> List[str]: - dest_collections = dest_client.get_collections().collections - dest_collection_names = {collection.name for collection in dest_collections} - existing_dest_collections = dest_collection_names & set(collection_names) - return list(existing_dest_collections) - - -def find_collisions( - source_client: QdrantBase, dest_client: QdrantBase, collection_names: List[str] -) -> List[str]: - collision_collection_names = [] - for collection_name in collection_names: - src_collection_info = source_client.get_collection(collection_name) - dest_collection_info = dest_client.get_collection(collection_name) - collision = check_collision(src_collection_info, dest_collection_info) - if collision: - collision_collection_names.append(collection_name) - - return collection_names - - -def check_collision( - src_collection_info: models.CollectionInfo, - dest_collection_info: models.CollectionInfo, -) -> bool: - """Check if collection configurations collide - - Args: - src_collection_info (models.CollectionInfo): Source collection info - dest_collection_info (models.CollectionInfo): Destination collection info - - Returns: - bool: True if collection configurations collide, False otherwise - """ - - def check_vector_params_collision( - src_vector_params: models.VectorParams, dest_vector_params: models.VectorParams - ) -> bool: - """Check if vector params collide - - Args: - src_vector_params (models.VectorParams): Source vector params - dest_vector_params (models.VectorParams): Destination vector params - - Returns: - bool: True if vector params collide, False otherwise - """ - if src_vector_params.size != dest_vector_params.size: - return True - - if src_vector_params.distance != dest_vector_params.distance: - return True - - return False - - source_vector_params = src_collection_info.config.params.vectors - destination_vector_params = dest_collection_info.config.params.vectors - - if isinstance(source_vector_params, models.VectorParams): - if not isinstance(destination_vector_params, models.VectorParams): - return True - - return check_vector_params_collision(source_vector_params, destination_vector_params) - - if isinstance(source_vector_params, dict): - if not isinstance(destination_vector_params, dict): - return True - - for key in source_vector_params: - if key not in destination_vector_params: - return True - - if check_vector_params_collision( - source_vector_params[key], destination_vector_params[key] - ): - return True - - return False - - def migrate_collection( source_client: QdrantBase, dest_client: QdrantBase, From 80c821b99056e45fed59b215aa73a181a84cb8c4 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Sat, 19 Aug 2023 18:07:38 +0200 Subject: [PATCH 09/11] new: update migration logic and tests --- qdrant_client/client_base.py | 2 +- qdrant_client/migrate/migrate.py | 54 ++-- qdrant_client/qdrant_client.py | 19 +- tests/congruence_tests/test_common.py | 7 +- tests/test_migrate.py | 427 +++++++++++++++++++++----- 5 files changed, 390 insertions(+), 119 deletions(-) diff --git a/qdrant_client/client_base.py b/qdrant_client/client_base.py index 905d155c..8d51dcfb 100644 --- a/qdrant_client/client_base.py +++ b/qdrant_client/client_base.py @@ -332,6 +332,6 @@ def migrate( dest_client: "QdrantBase", collection_names: Optional[List[str]] = None, batch_size: int = 100, - raise_on_collision: bool = True, + recreate_on_collision: bool = False, ) -> None: raise NotImplementedError() diff --git a/qdrant_client/migrate/migrate.py b/qdrant_client/migrate/migrate.py index b875a92f..1bd2698c 100644 --- a/qdrant_client/migrate/migrate.py +++ b/qdrant_client/migrate/migrate.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import Dict, List, Optional from qdrant_client import models @@ -6,19 +5,11 @@ from qdrant_client.client_base import QdrantBase -class MigrationCollisionAction(str, Enum): - """Action on collection collision""" - - RAISE = "Raise" - RECREATE = "Recreate" - SKIP = "Skip" - - def migrate( source_client: QdrantBase, dest_client: QdrantBase, - collection_names: Optional[List[str]], - on_collision_action: MigrationCollisionAction = MigrationCollisionAction.RAISE, + collection_names: Optional[List[str]] = None, + recreate_on_collision: bool = False, batch_size: int = 100, ): """ @@ -29,29 +20,29 @@ def migrate( dest_client (QdrantBase): Destination client collection_names (list[str], optional): List of collection names to migrate. If None - migrate all source client collections. Defaults to None. - on_collision_action (MigrationCollisionAction): Action on collection collision. Defaults to - MigrationCollisionAction.RAISE. + recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise + raise ValueError. batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. """ - collection_names = select_source_collections(source_client, collection_names) - collisions = find_collisions(dest_client, collection_names) + collection_names = _select_source_collections(source_client, collection_names) + collisions = _find_collisions(dest_client, collection_names) absent_dest_collections = set(collection_names) - set(collisions) - if collisions and on_collision_action == MigrationCollisionAction.RAISE: + if collisions and not recreate_on_collision: raise ValueError(f"Collections already exist in dest_client: {collisions}") - elif collisions and on_collision_action == MigrationCollisionAction.SKIP: - collisions = [] for collection_name in absent_dest_collections: - recreate_collection(source_client, dest_client, collection_name) - migrate_collection(source_client, dest_client, collection_name, batch_size) + _recreate_collection(source_client, dest_client, collection_name) + _migrate_collection(source_client, dest_client, collection_name, batch_size) for collection_name in collisions: - recreate_collection(source_client, dest_client, collection_name) - migrate_collection(source_client, dest_client, collection_name, batch_size) + _recreate_collection(source_client, dest_client, collection_name) + _migrate_collection(source_client, dest_client, collection_name, batch_size) -def select_source_collections(source_client: QdrantBase, collection_names: List[str]) -> List[str]: +def _select_source_collections( + source_client: QdrantBase, collection_names: Optional[List[str]] = None +) -> List[str]: source_collections = source_client.get_collections().collections source_collection_names = {collection.name for collection in source_collections} @@ -64,14 +55,14 @@ def select_source_collections(source_client: QdrantBase, collection_names: List[ return collection_names -def find_collisions(dest_client: QdrantBase, collection_names: List[str]) -> List[str]: +def _find_collisions(dest_client: QdrantBase, collection_names: List[str]) -> List[str]: dest_collections = dest_client.get_collections().collections dest_collection_names = {collection.name for collection in dest_collections} existing_dest_collections = dest_collection_names & set(collection_names) return list(existing_dest_collections) -def recreate_collection( +def _recreate_collection( source_client: QdrantBase, dest_client: QdrantBase, collection_name: str, @@ -93,10 +84,10 @@ def recreate_collection( quantization_config=src_config.quantization_config, ) - recreate_payload_schema(dest_client, collection_name, src_payload_schema) + _recreate_payload_schema(dest_client, collection_name, src_payload_schema) -def recreate_payload_schema( +def _recreate_payload_schema( dest_client: QdrantBase, collection_name: str, payload_schema: Dict[str, models.PayloadIndexInfo], @@ -105,11 +96,11 @@ def recreate_payload_schema( dest_client.create_payload_index( collection_name, field_name=field_name, - field_schema=field_info.type if field_info.params is None else field_info.params, + field_schema=field_info.data_type if field_info.params is None else field_info.params, ) -def migrate_collection( +def _migrate_collection( source_client: QdrantBase, dest_client: QdrantBase, collection_name: str, @@ -123,16 +114,13 @@ def migrate_collection( dest_client (QdrantBase): Destination client batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100. """ - records, next_offset = source_client.scroll( - collection_name, limit=batch_size, with_vectors=True - ) + records, next_offset = source_client.scroll(collection_name, limit=2, with_vectors=True) dest_client.upload_records(collection_name, records) while next_offset: records, next_offset = source_client.scroll( collection_name, offset=next_offset, limit=batch_size, with_vectors=True ) dest_client.upload_records(collection_name, records) - source_client_vectors_count = source_client.get_collection(collection_name).vectors_count dest_client_vectors_count = dest_client.get_collection(collection_name).vectors_count assert ( diff --git a/qdrant_client/qdrant_client.py b/qdrant_client/qdrant_client.py index 6e671a8e..a4c8af7c 100644 --- a/qdrant_client/qdrant_client.py +++ b/qdrant_client/qdrant_client.py @@ -472,7 +472,8 @@ def recommend( collection_name: Collection to search in positive: List of stored point IDs, which should be used as reference for similarity search. - If there is only one ID provided - this request is equivalent to the regular search with vector of that point. + If there is only one ID provided - this request is equivalent to the regular search with vector of that + point. If there are more than one IDs, Qdrant will attempt to search for similar to all of them. Recommendation for multiple vectors is experimental. Its behaviour may change in the future. negative: @@ -570,7 +571,8 @@ def recommend_groups( collection_name: Collection to search in positive: List of stored point IDs, which should be used as reference for similarity search. - If there is only one ID provided - this request is equivalent to the regular search with vector of that point. + If there is only one ID provided - this request is equivalent to the regular search with vector of that + point. If there are more than one IDs, Qdrant will attempt to search for similar to all of them. Recommendation for multiple vectors is experimental. Its behaviour may change in the future. negative: @@ -1739,12 +1741,21 @@ def migrate( dest_client: QdrantBase, collection_names: Optional[List[str]] = None, batch_size: int = 100, - raise_on_collision: bool = True, + recreate_on_collision: bool = False, ) -> None: + """Migrate data from one Qdrant instance to another. + + Args: + dest_client: Destination Qdrant instance either in local or remote mode + collection_names: List of collection names to migrate. If None - migrate all collections + batch_size: Batch size to be in scroll and upsert operations during migration + recreate_on_collision: If True - recreate collection on destination if it already exists, otherwise + raise ValueError exception + """ migrate( self, dest_client, collection_names=collection_names, batch_size=batch_size, - raise_on_collision=raise_on_collision, + recreate_on_collision=recreate_on_collision, ) diff --git a/tests/congruence_tests/test_common.py b/tests/congruence_tests/test_common.py index b3d1e9b0..f3e8f9d8 100644 --- a/tests/congruence_tests/test_common.py +++ b/tests/congruence_tests/test_common.py @@ -75,9 +75,10 @@ def compare_collections( client_2, num_vectors, attrs=("vectors_count", "indexed_vectors_count", "points_count"), + collection_name: str = COLLECTION_NAME, ): - collection_1 = client_1.get_collection(COLLECTION_NAME) - collection_2 = client_2.get_collection(COLLECTION_NAME) + collection_1 = client_1.get_collection(collection_name) + collection_2 = client_2.get_collection(collection_name) assert all(getattr(collection_1, attr) == getattr(collection_2, attr) for attr in attrs) @@ -85,7 +86,7 @@ def compare_collections( compare_client_results( client_1, client_2, - lambda client: client.scroll(COLLECTION_NAME, with_vectors=True, limit=num_vectors * 2), + lambda client: client.scroll(collection_name, with_vectors=True, limit=num_vectors * 2), ) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 57b99344..e6bc1ff6 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -1,130 +1,401 @@ -import numpy as np +from copy import deepcopy + import pytest -from qdrant_client import QdrantClient -from qdrant_client.http import models +import qdrant_client.http.exceptions as qdrant_exceptions +from qdrant_client import QdrantClient, models +from tests.congruence_tests.test_common import ( + compare_collections, + generate_fixtures, + initialize_fixture_collection, +) VECTOR_NUMBER = 1000 @pytest.fixture -def source_client() -> QdrantClient: +def local_client() -> QdrantClient: client = QdrantClient(":memory:") + delete_collections(client) yield client + delete_collections(client) client.close() +second_local_client = deepcopy(local_client) + + @pytest.fixture -def dest_client() -> QdrantClient: +def remote_client() -> QdrantClient: client = QdrantClient() + delete_collections(client) yield client + delete_collections(client) client.close() -def test_single_vector_collection(source_client: QdrantClient, dest_client: QdrantClient) -> None: - single_vector_collection_kwargs = { - "collection_name": "single_vector_collection", - "vectors_config": models.VectorParams(size=10, distance=models.Distance.COSINE), - } - source_client.recreate_collection(**single_vector_collection_kwargs) - dest_client.recreate_collection(**single_vector_collection_kwargs) +def delete_collections(client: QdrantClient) -> None: + collection_names = [collection.name for collection in client.get_collections().collections] + for collection_name in collection_names: + client.delete_collection(collection_name) - source_client.upload_collection( - single_vector_collection_kwargs["collection_name"], - vectors=np.random.randn( - VECTOR_NUMBER, single_vector_collection_kwargs["vectors_config"].size - ), + +@pytest.mark.parametrize( + "source_client,dest_client", + [ + ("local_client", "remote_client"), + ("remote_client", "local_client"), + ("local_client", "second_local_client"), + # remote - remote requires two launched Qdrant instances and is not tested + ], +) +def test_single_vector_collection(source_client, dest_client, request) -> None: + """ + Args: + source_client: fixture + dest_client: fixture + request: pytest internal object to get launch fixtures from parametrize + """ + source_client = request.getfixturevalue(source_client) + dest_client = request.getfixturevalue(dest_client) + vectors_config = models.VectorParams(size=10, distance=models.Distance.COSINE) + collection_name = "single_vector_collection" + initialize_fixture_collection( + source_client, collection_name=collection_name, vectors_config=vectors_config ) + records = generate_fixtures(VECTOR_NUMBER, vectors_sizes=vectors_config.size) + source_client.upload_records(collection_name, records) source_client.migrate(dest_client) + dest_client.upload_records(collection_name, records) + compare_collections( + source_client, + dest_client, + num_vectors=VECTOR_NUMBER, + collection_name=collection_name, + ) -def test_multiple_vectors_collection( - source_client: QdrantClient, dest_client: QdrantClient -) -> None: - multiple_vectors_collection_kwargs = { - "collection_name": "multiple_vectors_collection", - "vectors_config": { - "text": models.VectorParams(size=10, distance=models.Distance.EUCLID), - "image": models.VectorParams(size=11, distance=models.Distance.COSINE), - }, - } - source_client.recreate_collection(**multiple_vectors_collection_kwargs) - dest_client.recreate_collection(**multiple_vectors_collection_kwargs) - source_client.upload_collection( - multiple_vectors_collection_kwargs["collection_name"], - vectors={ - "text": np.random.randn( - VECTOR_NUMBER, - multiple_vectors_collection_kwargs["vectors_config"]["text"].size, - ), - "image": np.random.randn( - VECTOR_NUMBER, - multiple_vectors_collection_kwargs["vectors_config"]["image"].size, - ), - }, - ) +@pytest.mark.parametrize( + "source_client,dest_client", + [ + ("local_client", "remote_client"), + ("remote_client", "local_client"), + ("local_client", "second_local_client"), + # remote - remote requires two launched Qdrant instances and is not tested + ], +) +def test_multiple_vectors_collection(source_client, dest_client, request) -> None: + """ + Args: + source_client: fixture + dest_client: fixture + request: pytest internal object to get launch fixtures from parametrize + """ + source_client = request.getfixturevalue(source_client) + dest_client = request.getfixturevalue(dest_client) + collection_name = "multiple_vectors_collection" + initialize_fixture_collection(source_client, collection_name="multiple_vectors_collection") + records = generate_fixtures(VECTOR_NUMBER) + + source_client.upload_records(collection_name, records) source_client.migrate(dest_client) + compare_collections( + source_client, + dest_client, + num_vectors=VECTOR_NUMBER, + collection_name=collection_name, + ) -def test_multiple_collections(source_client: QdrantClient, dest_client: QdrantClient) -> None: +@pytest.mark.parametrize( + "source_client,dest_client", + [ + ("local_client", "remote_client"), + ("remote_client", "local_client"), + ("local_client", "second_local_client"), + # remote - remote requires two launched Qdrant instances and is not tested + ], +) +def test_migrate_all_collections(source_client, dest_client, request) -> None: + """ + Args: + source_client: fixture + dest_client: fixture + request: pytest internal object to get launch fixtures from parametrize + """ + vector_number = 100 collection_names = ["collection_1", "collection_2", "collection_3"] - vector_params = models.VectorParams(size=10, distance=models.Distance.COSINE) + source_client = request.getfixturevalue(source_client) + dest_client = request.getfixturevalue(dest_client) for collection_name in collection_names: - source_client.recreate_collection(collection_name, vectors_config=vector_params) - dest_client.recreate_collection(collection_name, vectors_config=vector_params) - source_client.upload_collection( + initialize_fixture_collection(source_client, collection_name=collection_name) + records = generate_fixtures(vector_number) + source_client.upload_records( collection_name, - vectors=np.random.randn(VECTOR_NUMBER, vector_params.size), + records, ) source_client.migrate(dest_client) for collection_name in collection_names: - source_vector_number = source_client.get_collection(collection_name).vectors_count - dest_vector_number = dest_client.get_collection(collection_name).vectors_count + compare_collections( + source_client, + dest_client, + num_vectors=vector_number, + collection_name=collection_name, + ) - assert ( - source_vector_number == dest_vector_number == VECTOR_NUMBER - ), f"Migration failed. Source vectors count {source_vector_number}, dest vectors count {dest_vector_number}, expected {VECTOR_NUMBER}" +@pytest.mark.parametrize( + "source_client,dest_client", + [ + ("local_client", "remote_client"), + ("remote_client", "local_client"), + ("local_client", "second_local_client"), + # remote - remote requires two launched Qdrant instances and is not tested + ], +) +def test_migrate_particular_collections(source_client, dest_client, request) -> None: + """ + Args: + source_client: fixture + dest_client: fixture + request: pytest internal object to get launch fixtures from parametrize + """ + vector_number = 100 + collection_names = ["collection_1", "collection_2", "collection_3"] + source_client = request.getfixturevalue(source_client) + dest_client = request.getfixturevalue(dest_client) + for collection_name in collection_names: + initialize_fixture_collection(source_client, collection_name=collection_name) + records = generate_fixtures(vector_number) + source_client.upload_records( + collection_name, + records, + ) -def test_different_distances(source_client: QdrantClient, dest_client: QdrantClient) -> None: - collection_name = "single_vector_collection" - cosine_params = models.VectorParams(size=10, distance=models.Distance.COSINE) - euclid_params = models.VectorParams(size=10, distance=models.Distance.EUCLID) + source_client.migrate(dest_client, collection_names=collection_names[:2]) - source_client.recreate_collection(collection_name, vectors_config=cosine_params) - dest_client.recreate_collection(collection_name, vectors_config=euclid_params) + for collection_name in collection_names[:2]: + compare_collections( + source_client, + dest_client, + num_vectors=vector_number, + collection_name=collection_name, + ) - with pytest.raises(AssertionError): - source_client.migrate(dest_client) + for collection_name in collection_names[2:]: + with pytest.raises((qdrant_exceptions.UnexpectedResponse, ValueError)): # type: ignore + dest_client.get_collection(collection_name) -def test_different_vector_sizes(source_client: QdrantClient, dest_client: QdrantClient) -> None: - collection_name = "single_vector_collection" - small_vector_params = models.VectorParams(size=10, distance=models.Distance.COSINE) - big_vector_params = models.VectorParams(size=100, distance=models.Distance.COSINE) +@pytest.mark.parametrize( + "source_client,dest_client", + [ + ("local_client", "remote_client"), + ("remote_client", "local_client"), + ("local_client", "second_local_client"), + # remote - remote requires two launched Qdrant instances and is not tested + ], +) +def test_action_on_collision(source_client, dest_client, request) -> None: + """ + Args: + source_client: fixture + dest_client: fixture + request: pytest internal object to get launch fixtures from parametrize + """ + collection_name = "test_collection" + source_client = request.getfixturevalue(source_client) + dest_client = request.getfixturevalue(dest_client) + initialize_fixture_collection(source_client, collection_name=collection_name) + initialize_fixture_collection(dest_client, collection_name=collection_name) - source_client.recreate_collection(collection_name, vectors_config=small_vector_params) - dest_client.recreate_collection(collection_name, vectors_config=big_vector_params) + with pytest.raises(ValueError): + source_client.migrate(dest_client, recreate_on_collision=False) - with pytest.raises(AssertionError): - source_client.migrate(dest_client) + records = generate_fixtures(VECTOR_NUMBER) + source_client.upload_records( + collection_name, + records, + ) + source_client.migrate(dest_client, recreate_on_collision=True) + compare_collections( + source_client, + dest_client, + num_vectors=VECTOR_NUMBER, + collection_name=collection_name, + ) -def test_single_vs_multiple_vectors( - source_client: QdrantClient, dest_client: QdrantClient -) -> None: +def test_vector_params( + local_client: QdrantClient, + second_local_client: QdrantClient, + remote_client: QdrantClient, +): collection_name = "test_collection" - single_vector_params = {"text": models.VectorParams(size=10, distance=models.Distance.COSINE)} - multiple_vectors_params = { + + image_hnsw_config = models.HnswConfigDiff( + m=9, + ef_construct=99, + full_scan_threshold=42, + max_indexing_threads=4, + on_disk=True, + payload_m=5, + ) + image_quantization_config = models.ScalarQuantization( + scalar=models.ScalarQuantizationConfig( + type=models.ScalarType.INT8, quantile=0.69, always_ram=False + ) + ) + + image_on_disk = True + + vectors_config = { "text": models.VectorParams(size=10, distance=models.Distance.COSINE), - "image": models.VectorParams(size=11, distance=models.Distance.COSINE), + "image": models.VectorParams( + size=20, + distance=models.Distance.DOT, + hnsw_config=image_hnsw_config, + quantization_config=image_quantization_config, + on_disk=image_on_disk, + ), } - source_client.recreate_collection(collection_name, vectors_config=single_vector_params) - dest_client.recreate_collection(collection_name, vectors_config=multiple_vectors_params) + local_client.recreate_collection( + collection_name=collection_name, vectors_config=vectors_config + ) + + local_client.migrate(second_local_client) + + assert local_client.get_collection(collection_name) == second_local_client.get_collection( + collection_name + ) + local_client.migrate(remote_client) + + local_collection_vector_params = local_client.get_collection( + collection_name + ).config.params.vectors + remote_collection_vector_params = remote_client.get_collection( + collection_name + ).config.params.vectors + + assert local_collection_vector_params == remote_collection_vector_params + + local_client.delete_collection(collection_name) + + remote_client.migrate(local_client) + local_collection_vector_params = local_client.get_collection( + collection_name + ).config.params.vectors + + assert local_collection_vector_params == remote_collection_vector_params + + +def test_migrate_missing_collections( + local_client: QdrantClient, second_local_client: QdrantClient +): + collection_name = "test_collection" with pytest.raises(AssertionError): - source_client.migrate(dest_client) + local_client.migrate(second_local_client, collection_names=[collection_name]) + + +def test_recreate_collection(remote_client: QdrantClient): + collection_name = "test_collection" + initialize_fixture_collection(remote_client, collection_name=collection_name) + collection_before_migrate = remote_client.get_collection(collection_name) + remote_client.migrate(remote_client, recreate_on_collision=True) + assert collection_before_migrate == remote_client.get_collection(collection_name) + + remote_client.delete_collection(collection_name) + + image_hnsw_config = models.HnswConfigDiff( + m=9, + ef_construct=99, + full_scan_threshold=4200, + max_indexing_threads=2, + on_disk=True, + payload_m=5, + ) + image_quantization_config = models.ScalarQuantization( + scalar=models.ScalarQuantizationConfig( + type=models.ScalarType.INT8, quantile=0.89, always_ram=False + ) + ) + + image_on_disk = True + + vectors_config = { + "text": models.VectorParams(size=10, distance=models.Distance.COSINE), + "image": models.VectorParams( + size=20, + distance=models.Distance.DOT, + hnsw_config=image_hnsw_config, + quantization_config=image_quantization_config, + on_disk=image_on_disk, + ), + } + + general_hnsw_config = models.HnswConfigDiff( + m=13, + ef_construct=101, + full_scan_threshold=10_001, + max_indexing_threads=1, + on_disk=True, + payload_m=16, + ) + optimizers_config = models.OptimizersConfigDiff( + deleted_threshold=0.21, + vacuum_min_vector_number=1001, + default_segment_number=2, + max_segment_size=42_000, + memmap_threshold=42_000, + indexing_threshold=42_000, + flush_interval_sec=6, + max_optimization_threads=2, + ) + + wal_config = models.WalConfigDiff(wal_capacity_mb=42, wal_segments_ahead=3) + + general_quantization_config = models.ProductQuantization( + product=models.ProductQuantizationConfig( + compression=models.CompressionRatio.X4, always_ram=False + ) + ) + + remote_client.recreate_collection( + collection_name, + vectors_config=vectors_config, + shard_number=3, + replication_factor=3, + write_consistency_factor=2, + on_disk_payload=True, + hnsw_config=general_hnsw_config, + optimizers_config=optimizers_config, + wal_config=wal_config, + quantization_config=general_quantization_config, + ) + + remote_client.create_payload_index( + collection_name, + field_name="title", + field_schema=models.PayloadSchemaType.KEYWORD, + ) + + remote_client.create_payload_index( + collection_name, + field_name="description", + field_schema=models.TextIndexParams( + type=models.TextIndexType.TEXT, + tokenizer=models.TokenizerType.MULTILINGUAL, + min_token_len=3, + max_token_len=5, + lowercase=False, + ), + ) + + collection_before_migrate = remote_client.get_collection(collection_name) + remote_client.migrate(remote_client, recreate_on_collision=True) + assert collection_before_migrate == remote_client.get_collection(collection_name) From 56d406761447f126b32bfe30b3a2f6f1500c7522 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Sat, 19 Aug 2023 18:11:32 +0200 Subject: [PATCH 10/11] fix: fix mypy complaints --- qdrant_client/client_base.py | 2 +- qdrant_client/migrate/migrate.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/qdrant_client/client_base.py b/qdrant_client/client_base.py index 8d51dcfb..5d98fe52 100644 --- a/qdrant_client/client_base.py +++ b/qdrant_client/client_base.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union -from qdrant_client import models from qdrant_client.conversions import common_types as types +from qdrant_client.http import models class QdrantBase: diff --git a/qdrant_client/migrate/migrate.py b/qdrant_client/migrate/migrate.py index 1bd2698c..257212f0 100644 --- a/qdrant_client/migrate/migrate.py +++ b/qdrant_client/migrate/migrate.py @@ -1,8 +1,8 @@ from typing import Dict, List, Optional -from qdrant_client import models from qdrant_client._pydantic_compat import to_dict from qdrant_client.client_base import QdrantBase +from qdrant_client.http import models def migrate( @@ -11,7 +11,7 @@ def migrate( collection_names: Optional[List[str]] = None, recreate_on_collision: bool = False, batch_size: int = 100, -): +) -> None: """ Migrate collections from source client to destination client @@ -44,12 +44,12 @@ def _select_source_collections( source_client: QdrantBase, collection_names: Optional[List[str]] = None ) -> List[str]: source_collections = source_client.get_collections().collections - source_collection_names = {collection.name for collection in source_collections} + source_collection_names = [collection.name for collection in source_collections] - if collection_names: + if collection_names is not None: assert all( collection_name in source_collection_names for collection_name in collection_names - ), f"Source client does not have collections: {set(collection_names) - source_collection_names}" + ), f"Source client does not have collections: {set(collection_names) - set(source_collection_names)}" else: collection_names = source_collection_names return collection_names From 566a91339e9908ec831bbfd725bcd52911f6f949 Mon Sep 17 00:00:00 2001 From: George Panchuk Date: Sat, 19 Aug 2023 18:22:13 +0200 Subject: [PATCH 11/11] tests: replace tokenizer type in favour of backward compatibility tests --- tests/test_migrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index e6bc1ff6..2d162ab6 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -389,7 +389,7 @@ def test_recreate_collection(remote_client: QdrantClient): field_name="description", field_schema=models.TextIndexParams( type=models.TextIndexType.TEXT, - tokenizer=models.TokenizerType.MULTILINGUAL, + tokenizer=models.TokenizerType.PREFIX, min_token_len=3, max_token_len=5, lowercase=False,