Skip to content

Commit

Permalink
fix: fix generation of grpc async create shard key, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joein committed Sep 24, 2024
1 parent 2c59378 commit bb5fd65
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
15 changes: 8 additions & 7 deletions qdrant_client/async_qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2765,16 +2765,17 @@ async def create_shard_key(
if self._prefer_grpc:
if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
shard_key = RestToGrpc.convert_shard_key(shard_key)
request = grpc.CreateShardKey(
shard_key=shard_key,
shards_number=shards_number,
replication_factor=replication_factor,
placement=placement or [],
)
return (
await self.grpc_collections.CreateShardKey(
grpc.CreateShardKeyRequest(
collection_name=collection_name, timeout=timeout, request=request
collection_name=collection_name,
timeout=timeout,
request=grpc.CreateShardKey(
shard_key=shard_key,
shards_number=shards_number,
replication_factor=replication_factor,
placement=placement or [],
),
),
timeout=self._timeout,
)
Expand Down
14 changes: 6 additions & 8 deletions qdrant_client/qdrant_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -3138,18 +3138,16 @@ def create_shard_key(
if isinstance(shard_key, get_args_subscribed(models.ShardKey)):
shard_key = RestToGrpc.convert_shard_key(shard_key)

request = grpc.CreateShardKey(
shard_key=shard_key,
shards_number=shards_number,
replication_factor=replication_factor,
placement=placement or [],
)

return self.grpc_collections.CreateShardKey(
grpc.CreateShardKeyRequest(
collection_name=collection_name,
timeout=timeout,
request=request,
request=grpc.CreateShardKey(
shard_key=shard_key,
shards_number=shards_number,
replication_factor=replication_factor,
placement=placement or [],
),
),
timeout=self._timeout,
).result
Expand Down
22 changes: 22 additions & 0 deletions tests/test_async_qdrant_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,3 +609,25 @@ def auth_token_provider():

await client.unlock_storage()
assert sync_token == "token_2"


@pytest.mark.asyncio
@pytest.mark.parametrize("prefer_grpc", [False, True])
async def test_custom_sharding(prefer_grpc):
client = AsyncQdrantClient(prefer_grpc=prefer_grpc)

if await client.collection_exists(COLLECTION_NAME):
await client.delete_collection(collection_name=COLLECTION_NAME)
await client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=models.VectorParams(size=DIM, distance=models.Distance.DOT),
sharding_method=models.ShardingMethod.CUSTOM,
)

await client.create_shard_key(collection_name=COLLECTION_NAME, shard_key="cats")
await client.create_shard_key(collection_name=COLLECTION_NAME, shard_key="dogs")

collection_info = await client.get_collection(COLLECTION_NAME)

assert collection_info.config.params.shard_number == 1
# assert collection_info.config.params.sharding_method == models.ShardingMethod.CUSTOM # todo: fix in grpc

0 comments on commit bb5fd65

Please sign in to comment.