Skip to content

Commit

Permalink
new: add migrate method to client's interface
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Aug 10, 2023
1 parent 5cac772 commit 2611f0e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 19 deletions.
9 changes: 9 additions & 0 deletions qdrant_client/client_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import sys
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

# `Self` was added to the standard `typing` module in Python 3.11; older
# interpreters fall back to the `typing_extensions` backport.  The check is
# written as `>= (3, 11)` because that is the form static type checkers are
# specified to evaluate, and it reads correctly without relying on
# `sys.version_info` being a longer tuple than `(3, 11)`.
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self

from qdrant_client.conversions import common_types as types
from qdrant_client.http import models

Expand Down Expand Up @@ -326,3 +332,6 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption:

def close(self, **kwargs: Any) -> None:
    """Do nothing; placeholder in the base client interface.

    Args:
        **kwargs: Accepted and ignored by this implementation.
    """
    return None

def migrate(self, dest_client: Self, batch_size: int = 100) -> None:
raise NotImplementedError()
24 changes: 12 additions & 12 deletions qdrant_client/migrate/migrate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import List

from qdrant_client import QdrantClient
from qdrant_client.client_base import QdrantBase
from qdrant_client.http import models


def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size: int = 100) -> None:
def migrate(source_client: QdrantBase, dest_client: QdrantBase, batch_size: int = 100) -> None:
"""Migrate all collections from source client to destination client
Args:
source_client (QdrantClient): Source client
dest_client (QdrantClient): Destination client
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
source_collections = source_client.get_collections().collections
Expand All @@ -34,15 +34,15 @@ def migrate(source_client: QdrantClient, dest_client: QdrantClient, batch_size:

def _compare_collections(
source_collection_names: List[str],
source_client: QdrantClient,
dest_client: QdrantClient,
source_client: QdrantBase,
dest_client: QdrantBase,
) -> bool:
"""Compare collections from source client and destination client
Args:
source_collection_names (list[str]): List of collection names
source_client (QdrantClient): Source client
dest_client (QdrantClient): Destination client
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
Returns:
bool: True if collections have the same vector and distance params
Expand Down Expand Up @@ -81,16 +81,16 @@ def _compare_collections(

def _migrate_collection(
collection_name: str,
source_client: QdrantClient,
dest_client: QdrantClient,
source_client: QdrantBase,
dest_client: QdrantBase,
batch_size: int = 100,
) -> None:
"""Migrate collection from source client to destination client
Args:
collection_name (str): Collection name
source_client (QdrantClient): Source client
dest_client (QdrantClient): Destination client
source_client (QdrantBase): Source client
dest_client (QdrantBase): Destination client
batch_size (int, optional): Batch size for scrolling and uploading vectors. Defaults to 100.
"""
records, next_offset = source_client.scroll(
Expand Down
4 changes: 4 additions & 0 deletions qdrant_client/qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from qdrant_client.conversions import common_types as types
from qdrant_client.http import ApiClient, SyncApis
from qdrant_client.local.qdrant_local import QdrantLocal
from qdrant_client.migrate import migrate
from qdrant_client.qdrant_remote import QdrantRemote


Expand Down Expand Up @@ -1708,3 +1709,6 @@ def get_locks(self, **kwargs: Any) -> types.LocksOption:
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"

return self._client.get_locks(**kwargs)

def migrate(self, dest_client: QdrantBase, batch_size: int = 100) -> None:
    """Migrate from this client to ``dest_client``.

    Delegates to the ``migrate`` function imported from
    ``qdrant_client.migrate``, passing this client as the source.

    Args:
        dest_client: Destination client.
        batch_size: Forwarded unchanged to the ``migrate`` function.
            Defaults to 100.
    """
    migrate(
        source_client=self,
        dest_client=dest_client,
        batch_size=batch_size,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.migrate import migrate

VECTOR_NUMBER = 1000

Expand Down Expand Up @@ -37,7 +36,7 @@ def test_single_vector_collection(source_client: QdrantClient, dest_client: Qdra
),
)

migrate(source_client, dest_client)
source_client.migrate(dest_client)


def test_multiple_vectors_collection(
Expand Down Expand Up @@ -65,6 +64,7 @@ def test_multiple_vectors_collection(
),
},
)
source_client.migrate(dest_client)


def test_multiple_collections(source_client: QdrantClient, dest_client: QdrantClient) -> None:
Expand All @@ -78,7 +78,7 @@ def test_multiple_collections(source_client: QdrantClient, dest_client: QdrantCl
vectors=np.random.randn(VECTOR_NUMBER, vector_params.size),
)

migrate(source_client, dest_client)
source_client.migrate(dest_client)

for collection_name in collection_names:
source_vector_number = source_client.get_collection(collection_name).vectors_count
Expand All @@ -98,7 +98,7 @@ def test_different_distances(source_client: QdrantClient, dest_client: QdrantCli
dest_client.recreate_collection(collection_name, vectors_config=euclid_params)

with pytest.raises(AssertionError):
migrate(source_client, dest_client)
source_client.migrate(dest_client)


def test_different_vector_sizes(source_client: QdrantClient, dest_client: QdrantClient) -> None:
Expand All @@ -110,7 +110,7 @@ def test_different_vector_sizes(source_client: QdrantClient, dest_client: Qdrant
dest_client.recreate_collection(collection_name, vectors_config=big_vector_params)

with pytest.raises(AssertionError):
migrate(source_client, dest_client)
source_client.migrate(dest_client)


def test_single_vs_multiple_vectors(
Expand All @@ -127,4 +127,4 @@ def test_single_vs_multiple_vectors(
dest_client.recreate_collection(collection_name, vectors_config=multiple_vectors_params)

with pytest.raises(AssertionError):
migrate(source_client, dest_client)
source_client.migrate(dest_client)
1 change: 0 additions & 1 deletion tests/test_qdrant_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import random
import uuid
from pprint import pprint
from tempfile import mkdtemp
Expand Down

0 comments on commit 2611f0e

Please sign in to comment.