Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migration tool #243

Merged
merged 11 commits into from
Aug 21, 2023
9 changes: 9 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,12 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption:

def close(self, **kwargs: Any) -> None:
pass

def migrate(
self,
dest_client: "QdrantBase",
collection_names: Optional[List[str]] = None,
batch_size: int = 100,
recreate_on_collision: bool = False,
) -> None:
raise NotImplementedError()
1 change: 1 addition & 0 deletions qdrant_client/migrate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .migrate import migrate
128 changes: 128 additions & 0 deletions qdrant_client/migrate/migrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Dict, List, Optional

from qdrant_client._pydantic_compat import to_dict
from qdrant_client.client_base import QdrantBase
from qdrant_client.http import models


def migrate(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_names: Optional[List[str]] = None,
recreate_on_collision: bool = False,
batch_size: int = 100,
) -> None:
"""
Migrate collections from source client to destination client

Args:
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
collection_names (list[str], optional): List of collection names to migrate.
If None - migrate all source client collections. Defaults to None.
recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise
raise ValueError.
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
collection_names = _select_source_collections(source_client, collection_names)
collisions = _find_collisions(dest_client, collection_names)
absent_dest_collections = set(collection_names) - set(collisions)

if collisions and not recreate_on_collision:
raise ValueError(f"Collections already exist in dest_client: {collisions}")

for collection_name in absent_dest_collections:
_recreate_collection(source_client, dest_client, collection_name)
_migrate_collection(source_client, dest_client, collection_name, batch_size)

for collection_name in collisions:
_recreate_collection(source_client, dest_client, collection_name)
_migrate_collection(source_client, dest_client, collection_name, batch_size)


def _select_source_collections(
source_client: QdrantBase, collection_names: Optional[List[str]] = None
) -> List[str]:
source_collections = source_client.get_collections().collections
source_collection_names = [collection.name for collection in source_collections]

if collection_names is not None:
assert all(
collection_name in source_collection_names for collection_name in collection_names
), f"Source client does not have collections: {set(collection_names) - set(source_collection_names)}"
else:
collection_names = source_collection_names
return collection_names


def _find_collisions(dest_client: QdrantBase, collection_names: List[str]) -> List[str]:
dest_collections = dest_client.get_collections().collections
dest_collection_names = {collection.name for collection in dest_collections}
existing_dest_collections = dest_collection_names & set(collection_names)
return list(existing_dest_collections)


def _recreate_collection(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_name: str,
) -> None:
src_collection_info = source_client.get_collection(collection_name)
src_config = src_collection_info.config
src_payload_schema = src_collection_info.payload_schema

dest_client.recreate_collection(
collection_name,
vectors_config=src_config.params.vectors,
shard_number=src_config.params.shard_number,
replication_factor=src_config.params.replication_factor,
write_consistency_factor=src_config.params.write_consistency_factor,
on_disk_payload=src_config.params.on_disk_payload,
hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)),
wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
quantization_config=src_config.quantization_config,
)

_recreate_payload_schema(dest_client, collection_name, src_payload_schema)


def _recreate_payload_schema(
dest_client: QdrantBase,
collection_name: str,
payload_schema: Dict[str, models.PayloadIndexInfo],
) -> None:
for field_name, field_info in payload_schema.items():
dest_client.create_payload_index(
collection_name,
field_name=field_name,
field_schema=field_info.data_type if field_info.params is None else field_info.params,
)


def _migrate_collection(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_name: str,
batch_size: int = 100,
) -> None:
"""Migrate collection from source client to destination client

Args:
collection_name (str): Collection name
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
records, next_offset = source_client.scroll(collection_name, limit=2, with_vectors=True)
dest_client.upload_records(collection_name, records)
while next_offset:
records, next_offset = source_client.scroll(
collection_name, offset=next_offset, limit=batch_size, with_vectors=True
)
dest_client.upload_records(collection_name, records)
source_client_vectors_count = source_client.get_collection(collection_name).vectors_count
dest_client_vectors_count = dest_client.get_collection(collection_name).vectors_count
assert (
source_client_vectors_count == dest_client_vectors_count
), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}"
31 changes: 29 additions & 2 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from qdrant_client.http import ApiClient, SyncApis
from qdrant_client.local.qdrant_local import QdrantLocal
from qdrant_client.qdrant_fastembed import QdrantFastembedMixin
from qdrant_client.migrate import migrate
from qdrant_client.qdrant_remote import QdrantRemote


Expand Down Expand Up @@ -471,7 +472,8 @@ def recommend(
collection_name: Collection to search in
positive:
List of stored point IDs, which should be used as reference for similarity search.
If there is only one ID provided - this request is equivalent to the regular search with vector of that point.
If there is only one ID provided - this request is equivalent to the regular search with vector of that
point.
If there are more than one IDs, Qdrant will attempt to search for similar to all of them.
Recommendation for multiple vectors is experimental. Its behaviour may change in the future.
negative:
Expand Down Expand Up @@ -569,7 +571,8 @@ def recommend_groups(
collection_name: Collection to search in
positive:
List of stored point IDs, which should be used as reference for similarity search.
If there is only one ID provided - this request is equivalent to the regular search with vector of that point.
If there is only one ID provided - this request is equivalent to the regular search with vector of that
point.
If there are more than one IDs, Qdrant will attempt to search for similar to all of them.
Recommendation for multiple vectors is experimental. Its behaviour may change in the future.
negative:
Expand Down Expand Up @@ -1732,3 +1735,27 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption:
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"

return self._client.get_locks(**kwargs)

def migrate(
self,
dest_client: QdrantBase,
collection_names: Optional[List[str]] = None,
batch_size: int = 100,
recreate_on_collision: bool = False,
) -> None:
"""Migrate data from one Qdrant instance to another.

Args:
dest_client: Destination Qdrant instance either in local or remote mode
collection_names: List of collection names to migrate. If None - migrate all collections
batch_size: Batch size to be in scroll and upsert operations during migration
recreate_on_collision: If True - recreate collection on destination if it already exists, otherwise
raise ValueError exception
"""
migrate(
self,
dest_client,
collection_names=collection_names,
batch_size=batch_size,
recreate_on_collision=recreate_on_collision,
)
7 changes: 4 additions & 3 deletions tests/congruence_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,18 @@ def compare_collections(
client_2,
num_vectors,
attrs=("vectors_count", "indexed_vectors_count", "points_count"),
collection_name: str = COLLECTION_NAME,
):
collection_1 = client_1.get_collection(COLLECTION_NAME)
collection_2 = client_2.get_collection(COLLECTION_NAME)
collection_1 = client_1.get_collection(collection_name)
collection_2 = client_2.get_collection(collection_name)

assert all(getattr(collection_1, attr) == getattr(collection_2, attr) for attr in attrs)

# num_vectors * 2 to be sure that we have no excess points uploaded
compare_client_results(
client_1,
client_2,
lambda client: client.scroll(COLLECTION_NAME, with_vectors=True, limit=num_vectors * 2),
lambda client: client.scroll(collection_name, with_vectors=True, limit=num_vectors * 2),
)


Expand Down
Loading