Skip to content

Commit

Permalink
Migration tool (#243)
Browse files Browse the repository at this point in the history
* new: migration tool v1

* new: add docstrings

* fix: fix collection comparison, add tests

* fix: fix mypy complaints

* tests: test multiple collections scenario

* new: add migrate method to client's interface

* new: update migration tool

* new: simplify collisions

* new: update migration logic and tests

* fix: fix mypy complaints

* tests: replace tokenizer type in favour of backward compatibility tests
  • Loading branch information
joein authored Aug 21, 2023
1 parent 898de18 commit 897a7ff
Show file tree
Hide file tree
Showing 6 changed files with 572 additions and 5 deletions.
9 changes: 9 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,12 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption:

def close(self, **kwargs: Any) -> None:
pass

def migrate(
self,
dest_client: "QdrantBase",
collection_names: Optional[List[str]] = None,
batch_size: int = 100,
recreate_on_collision: bool = False,
) -> None:
raise NotImplementedError()
1 change: 1 addition & 0 deletions qdrant_client/migrate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .migrate import migrate
128 changes: 128 additions & 0 deletions qdrant_client/migrate/migrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Dict, List, Optional

from qdrant_client._pydantic_compat import to_dict
from qdrant_client.client_base import QdrantBase
from qdrant_client.http import models


def migrate(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_names: Optional[List[str]] = None,
recreate_on_collision: bool = False,
batch_size: int = 100,
) -> None:
"""
Migrate collections from source client to destination client
Args:
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
collection_names (list[str], optional): List of collection names to migrate.
If None - migrate all source client collections. Defaults to None.
recreate_on_collision (bool, optional): If True - recreate collection if it exists, otherwise
raise ValueError.
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
collection_names = _select_source_collections(source_client, collection_names)
collisions = _find_collisions(dest_client, collection_names)
absent_dest_collections = set(collection_names) - set(collisions)

if collisions and not recreate_on_collision:
raise ValueError(f"Collections already exist in dest_client: {collisions}")

for collection_name in absent_dest_collections:
_recreate_collection(source_client, dest_client, collection_name)
_migrate_collection(source_client, dest_client, collection_name, batch_size)

for collection_name in collisions:
_recreate_collection(source_client, dest_client, collection_name)
_migrate_collection(source_client, dest_client, collection_name, batch_size)


def _select_source_collections(
source_client: QdrantBase, collection_names: Optional[List[str]] = None
) -> List[str]:
source_collections = source_client.get_collections().collections
source_collection_names = [collection.name for collection in source_collections]

if collection_names is not None:
assert all(
collection_name in source_collection_names for collection_name in collection_names
), f"Source client does not have collections: {set(collection_names) - set(source_collection_names)}"
else:
collection_names = source_collection_names
return collection_names


def _find_collisions(dest_client: QdrantBase, collection_names: List[str]) -> List[str]:
dest_collections = dest_client.get_collections().collections
dest_collection_names = {collection.name for collection in dest_collections}
existing_dest_collections = dest_collection_names & set(collection_names)
return list(existing_dest_collections)


def _recreate_collection(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_name: str,
) -> None:
src_collection_info = source_client.get_collection(collection_name)
src_config = src_collection_info.config
src_payload_schema = src_collection_info.payload_schema

dest_client.recreate_collection(
collection_name,
vectors_config=src_config.params.vectors,
shard_number=src_config.params.shard_number,
replication_factor=src_config.params.replication_factor,
write_consistency_factor=src_config.params.write_consistency_factor,
on_disk_payload=src_config.params.on_disk_payload,
hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
optimizers_config=models.OptimizersConfigDiff(**to_dict(src_config.optimizer_config)),
wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
quantization_config=src_config.quantization_config,
)

_recreate_payload_schema(dest_client, collection_name, src_payload_schema)


def _recreate_payload_schema(
dest_client: QdrantBase,
collection_name: str,
payload_schema: Dict[str, models.PayloadIndexInfo],
) -> None:
for field_name, field_info in payload_schema.items():
dest_client.create_payload_index(
collection_name,
field_name=field_name,
field_schema=field_info.data_type if field_info.params is None else field_info.params,
)


def _migrate_collection(
source_client: QdrantBase,
dest_client: QdrantBase,
collection_name: str,
batch_size: int = 100,
) -> None:
"""Migrate collection from source client to destination client
Args:
collection_name (str): Collection name
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
records, next_offset = source_client.scroll(collection_name, limit=2, with_vectors=True)
dest_client.upload_records(collection_name, records)
while next_offset:
records, next_offset = source_client.scroll(
collection_name, offset=next_offset, limit=batch_size, with_vectors=True
)
dest_client.upload_records(collection_name, records)
source_client_vectors_count = source_client.get_collection(collection_name).vectors_count
dest_client_vectors_count = dest_client.get_collection(collection_name).vectors_count
assert (
source_client_vectors_count == dest_client_vectors_count
), f"Migration failed, vectors count are not equal: source vector count {source_client_vectors_count}, dest vector count {dest_client_vectors_count}"
31 changes: 29 additions & 2 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from qdrant_client.http import ApiClient, SyncApis
from qdrant_client.local.qdrant_local import QdrantLocal
from qdrant_client.qdrant_fastembed import QdrantFastembedMixin
from qdrant_client.migrate import migrate
from qdrant_client.qdrant_remote import QdrantRemote


Expand Down Expand Up @@ -471,7 +472,8 @@ def recommend(
collection_name: Collection to search in
positive:
List of stored point IDs, which should be used as reference for similarity search.
If there is only one ID provided - this request is equivalent to the regular search with vector of that point.
If there is only one ID provided - this request is equivalent to the regular search with vector of that
point.
If there are more than one IDs, Qdrant will attempt to search for similar to all of them.
Recommendation for multiple vectors is experimental. Its behaviour may change in the future.
negative:
Expand Down Expand Up @@ -569,7 +571,8 @@ def recommend_groups(
collection_name: Collection to search in
positive:
List of stored point IDs, which should be used as reference for similarity search.
If there is only one ID provided - this request is equivalent to the regular search with vector of that point.
If there is only one ID provided - this request is equivalent to the regular search with vector of that
point.
If there are more than one IDs, Qdrant will attempt to search for similar to all of them.
Recommendation for multiple vectors is experimental. Its behaviour may change in the future.
negative:
Expand Down Expand Up @@ -1732,3 +1735,27 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption:
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"

return self._client.get_locks(**kwargs)

def migrate(
self,
dest_client: QdrantBase,
collection_names: Optional[List[str]] = None,
batch_size: int = 100,
recreate_on_collision: bool = False,
) -> None:
"""Migrate data from one Qdrant instance to another.
Args:
dest_client: Destination Qdrant instance either in local or remote mode
collection_names: List of collection names to migrate. If None - migrate all collections
batch_size: Batch size to be in scroll and upsert operations during migration
recreate_on_collision: If True - recreate collection on destination if it already exists, otherwise
raise ValueError exception
"""
migrate(
self,
dest_client,
collection_names=collection_names,
batch_size=batch_size,
recreate_on_collision=recreate_on_collision,
)
7 changes: 4 additions & 3 deletions tests/congruence_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,18 @@ def compare_collections(
client_2,
num_vectors,
attrs=("vectors_count", "indexed_vectors_count", "points_count"),
collection_name: str = COLLECTION_NAME,
):
collection_1 = client_1.get_collection(COLLECTION_NAME)
collection_2 = client_2.get_collection(COLLECTION_NAME)
collection_1 = client_1.get_collection(collection_name)
collection_2 = client_2.get_collection(collection_name)

assert all(getattr(collection_1, attr) == getattr(collection_2, attr) for attr in attrs)

# num_vectors * 2 to be sure that we have no excess points uploaded
compare_client_results(
client_1,
client_2,
lambda client: client.scroll(COLLECTION_NAME, with_vectors=True, limit=num_vectors * 2),
lambda client: client.scroll(collection_name, with_vectors=True, limit=num_vectors * 2),
)


Expand Down
Loading

0 comments on commit 897a7ff

Please sign in to comment.