
Commit a2adc2d

new: deprecate upload records, update tests, prohibit migration of collections with custom shards (#447)

* new: deprecate upload records, update tests, prohibit migration of collections with custom shards

* Update qdrant_client/qdrant_client.py

* Update qdrant_client/async_qdrant_client.py

* new: replace autogenerated int ids with uuids (#448)

* fix: remove redundant kwargs

* fix: regen async

* fix: remove redundant import

---------

Co-authored-by: Andrey Vasnetsov <[email protected]>
joein and generall authored Jan 19, 2024
1 parent 1c3abc0 commit a2adc2d
Showing 39 changed files with 767 additions and 400 deletions.
5 changes: 5 additions & 0 deletions qdrant_client/async_client_base.py
@@ -277,6 +277,11 @@ def upload_records(
) -> None:
raise NotImplementedError()

async def upload_points(
self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
) -> None:
raise NotImplementedError()

def upload_collection(
self,
collection_name: str,
55 changes: 53 additions & 2 deletions qdrant_client/async_qdrant_client.py
@@ -9,6 +9,7 @@
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******

import warnings
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

from qdrant_client import grpc as grpc
@@ -1741,6 +1742,11 @@ def upload_records(
This parameter overwrites shard keys written in the records.
"""
warnings.warn(
"`upload_records` is deprecated, use `upload_points` instead",
DeprecationWarning,
stacklevel=2,
)
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
return self._client.upload_records(
collection_name=collection_name,
@@ -1751,7 +1757,53 @@ def upload_records(
max_retries=max_retries,
wait=wait,
shard_key_selector=shard_key_selector,
**kwargs,
)

async def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
batch_size: int = 64,
parallel: int = 1,
method: Optional[str] = None,
max_retries: int = 3,
wait: bool = False,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> None:
"""Upload points to the collection
Similar to `upload_collection` method, but operates with points, rather than vector and payload individually.
Args:
collection_name: Name of the collection to upload to
points: Iterator over points to upload
batch_size: How many vectors upload per-request, Default: 64
parallel: Number of parallel processes of upload
method: Start method for parallel processes, Default: forkserver
max_retries: maximum number of retries in case of a failure
during the upload of a batch
wait:
Await for the results to be applied on the server side.
If `true`, each update request will explicitly wait for the confirmation of completion. Might be slower.
If `false`, each update request will return immediately after the confirmation of receiving.
Default: `false`
shard_key_selector: Defines the shard groups that should be used to write updates into.
If multiple shard_keys are provided, the update will be written to each of them.
Only works for collections with `custom` sharding method.
This parameter overwrites shard keys written in the records.
"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
return await self._client.upload_points(
collection_name=collection_name,
points=points,
batch_size=batch_size,
parallel=parallel,
method=method,
max_retries=max_retries,
wait=wait,
shard_key_selector=shard_key_selector,
)

def upload_collection(
@@ -1806,7 +1858,6 @@ def upload_collection(
max_retries=max_retries,
wait=wait,
shard_key_selector=shard_key_selector,
**kwargs,
)

async def create_payload_index(
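The deprecation above changes only the entry point: `upload_points` accepts the same keyword arguments as `upload_records`, but takes `PointStruct` objects via `points` instead of `Record` objects via `records`. A minimal migration sketch, assuming a local in-memory client and an illustrative collection named "demo":

import uuid

from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")
client.create_collection(
    collection_name="demo",
    vectors_config=models.VectorParams(size=2, distance=models.Distance.COSINE),
)

# Before: client.upload_records(collection_name="demo", records=[models.Record(...)])
# After: same call shape, but with PointStruct under the `points` keyword.
client.upload_points(
    collection_name="demo",
    points=[
        models.PointStruct(id=str(uuid.uuid4()), vector=[0.1, 0.2], payload={"doc": "hello"}),
    ],
    wait=True,
)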
14 changes: 6 additions & 8 deletions qdrant_client/async_qdrant_fastembed.py
@@ -13,8 +13,6 @@
from itertools import tee
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from pydantic import BaseModel

from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.conversions import common_types as types
from qdrant_client.fastembed_common import QueryResponse
@@ -172,13 +170,13 @@ def _scored_points_to_query_responses(
)
return response

def _records_iterator(
def _points_iterator(
self,
ids: Optional[Iterable[models.ExtendedPointId]],
metadata: Optional[Iterable[Dict[str, Any]]],
encoded_docs: Iterable[Tuple[str, List[float]]],
ids_accumulator: list,
) -> Iterable[models.Record]:
) -> Iterable[models.PointStruct]:
if ids is None:
ids = iter(lambda: uuid.uuid4().hex, None)
if metadata is None:
@@ -187,7 +185,7 @@ def _records_iterator(
for idx, meta, (doc, vector) in zip(ids, metadata, encoded_docs):
ids_accumulator.append(idx)
payload = {"document": doc, **meta}
yield models.Record(id=idx, payload=payload, vector={vector_name: vector})
yield models.PointStruct(id=idx, payload=payload, vector={vector_name: vector})

def get_fastembed_vector_params(
self,
@@ -289,12 +287,12 @@ async def add(
distance == vector_params.distance
), f"Distance mismatch: {distance} != {vector_params.distance}"
inserted_ids: list = []
records = self._records_iterator(
points = self._points_iterator(
ids=ids, metadata=metadata, encoded_docs=encoded_docs, ids_accumulator=inserted_ids
)
self.upload_records(
await self.upload_points(
collection_name=collection_name,
records=records,
points=points,
wait=True,
parallel=parallel or 1,
batch_size=batch_size,
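The id generation in `_points_iterator` leans on the two-argument form of `iter`: `iter(callable, sentinel)` calls the callable on every step until it returns the sentinel. Since `uuid.uuid4().hex` never returns `None`, the result is an endless stream of fresh ids. A standalone sketch of the idiom:

import uuid
from itertools import islice

# iter(callable, sentinel) re-invokes the callable on each next();
# uuid4().hex never equals the None sentinel, so the stream never ends.
ids = iter(lambda: uuid.uuid4().hex, None)
print(list(islice(ids, 3)))  # three distinct 32-character hex strings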
25 changes: 25 additions & 0 deletions qdrant_client/async_qdrant_remote.py
@@ -2054,6 +2054,31 @@ def upload_records(
wait=wait,
)

async def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
batch_size: int = 64,
parallel: int = 1,
method: Optional[str] = None,
max_retries: int = 3,
wait: bool = False,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> None:
batches_iterator = self._updater_class.iterate_records_batches(
records=points, batch_size=batch_size
)
self._upload_collection(
batches_iterator=batches_iterator,
collection_name=collection_name,
max_retries=max_retries,
parallel=parallel,
method=method,
wait=wait,
shard_key_selector=shard_key_selector,
)

def upload_collection(
self,
collection_name: str,
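`iterate_records_batches` belongs to the internal updater class, so only its observable behavior is sketched here: it chunks the incoming iterable into batches of `batch_size`. An illustrative stand-in under that assumption (the name `iterate_batches` is hypothetical, not the library's API):

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def iterate_batches(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    # Yield lists of at most batch_size consecutive elements.
    it = iter(items)
    while batch := list(islice(it, batch_size)):
        yield batch

print(list(iterate_batches(range(5), 2)))  # [[0, 1], [2, 3], [4]]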
8 changes: 8 additions & 0 deletions qdrant_client/client_base.py
@@ -296,6 +296,14 @@ def upload_records(
) -> None:
raise NotImplementedError()

def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
**kwargs: Any,
) -> None:
raise NotImplementedError()

def upload_collection(
self,
collection_name: str,
3 changes: 2 additions & 1 deletion qdrant_client/conversions/common_types.py
@@ -8,7 +8,7 @@
else:
from typing_extensions import TypeAlias

from typing import List, Tuple, Union, get_args
from typing import List, Union, get_args

from qdrant_client import grpc as grpc
from qdrant_client.http import models as rest
@@ -53,6 +53,7 @@ def get_args_subscribed(tp: type): # type: ignore
PayloadSchemaType = Union[
rest.PayloadSchemaType, rest.PayloadSchemaParams, int, grpc.PayloadIndexParams
] # type(grpc.PayloadSchemaType) == int
PointStruct: TypeAlias = rest.PointStruct
Points = Union[rest.Batch, List[Union[rest.PointStruct, grpc.PointStruct]]]
PointsSelector = Union[
List[PointId], rest.Filter, grpc.Filter, rest.PointsSelector, grpc.PointsSelector
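Since the new alias is a plain assignment, `types.PointStruct` and the REST model are the same class at runtime, so existing REST-style code keeps working unchanged; a quick sanity check (assuming the package is importable):

from qdrant_client.conversions import common_types as types
from qdrant_client.http import models as rest

assert types.PointStruct is rest.PointStruct  # one class, two import paths
point = types.PointStruct(id=1, vector=[0.0, 1.0], payload={"tag": "demo"})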
34 changes: 30 additions & 4 deletions qdrant_client/local/async_qdrant_local.py
@@ -15,7 +15,19 @@
import os
import shutil
from io import TextIOWrapper
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from uuid import uuid4

import numpy as np
import portalocker
@@ -643,16 +655,26 @@ async def recreate_collection(
collection_name, vectors_config, init_from, sparse_vectors_config
)

async def upload_points(
self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
) -> None:
self._upload_points(collection_name, points)

def upload_records(
self, collection_name: str, records: Iterable[types.Record], **kwargs: Any
) -> None:
self._upload_points(collection_name, records)

def _upload_points(
self, collection_name: str, points: Iterable[Union[types.PointStruct, types.Record]]
) -> None:
collection = self._get_collection(collection_name)
collection.upsert(
[
rest_models.PointStruct(
id=record.id, vector=record.vector or {}, payload=record.payload or {}
id=point.id, vector=point.vector or {}, payload=point.payload or {}
)
for record in records
for point in points
]
)

@@ -666,6 +688,10 @@ def upload_collection(
ids: Optional[Iterable[types.PointId]] = None,
**kwargs: Any,
) -> None:
def uuid_generator() -> Generator[str, None, None]:
while True:
yield str(uuid4())

collection = self._get_collection(collection_name)
if isinstance(vectors, dict) and any(
(isinstance(v, np.ndarray) for v in vectors.values())
@@ -686,7 +712,7 @@ def upload_collection(
payload=payload or {},
)
for (point_id, vector, payload) in zip(
ids or itertools.count(), iter(vectors), payload or itertools.cycle([{}])
ids or uuid_generator(), iter(vectors), payload or itertools.cycle([{}])
)
]
)
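The switch from `itertools.count()` to `uuid_generator()` is the substance of #448: sequential integer ids restart at 0 on every `upload_collection` call, so a second upload into the same collection silently upserts over the first batch, while uuids stay unique across calls. A side-by-side sketch of the two defaults:

import itertools
from uuid import uuid4

# Old default: restarts at 0 at each call site, so repeated uploads collide.
old_ids = itertools.count()

# New default: globally unique ids, no collisions across calls.
def uuid_generator():
    while True:
        yield str(uuid4())

new_ids = uuid_generator()
print(next(old_ids), next(new_ids))  # 0 and some uuid string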
43 changes: 36 additions & 7 deletions qdrant_client/local/qdrant_local.py
@@ -4,7 +4,19 @@
import os
import shutil
from io import TextIOWrapper
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from uuid import uuid4

import numpy as np
import portalocker
@@ -643,7 +655,10 @@ def create_collection(
if src_collection and from_collection_name:
batch_size = 100
records, next_offset = self.scroll(from_collection_name, limit=2, with_vectors=True)
self.upload_records(collection_name, records)
self.upload_records(
collection_name, records
) # it is not crucial to replace upload_records here,
# since this is internal usage and qdrant local has no custom shard keys
while next_offset is not None:
records, next_offset = self.scroll(
from_collection_name, offset=next_offset, limit=batch_size, with_vectors=True
@@ -666,20 +681,31 @@ def recreate_collection(
collection_name, vectors_config, init_from, sparse_vectors_config
)

def upload_points(
self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
) -> None:
self._upload_points(collection_name, points)

def upload_records(
self, collection_name: str, records: Iterable[types.Record], **kwargs: Any
) -> None:
# upload_records in local mode behaves like upload_records with wait=True in server mode
self._upload_points(collection_name, records)

def _upload_points(
self,
collection_name: str,
points: Iterable[Union[types.PointStruct, types.Record]],
) -> None:
collection = self._get_collection(collection_name)
collection.upsert(
[
rest_models.PointStruct(
id=record.id,
vector=record.vector or {},
payload=record.payload or {},
id=point.id,
vector=point.vector or {},
payload=point.payload or {},
)
for record in records
for point in points
]
)

@@ -694,6 +720,9 @@ def upload_collection(
**kwargs: Any,
) -> None:
# upload_collection in local mode behaves like upload_collection with wait=True in server mode
def uuid_generator() -> Generator[str, None, None]:
while True:
yield str(uuid4())

collection = self._get_collection(collection_name)
if isinstance(vectors, dict) and any(isinstance(v, np.ndarray) for v in vectors.values()):
Expand All @@ -716,7 +745,7 @@ def upload_collection(
payload=payload or {},
)
for (point_id, vector, payload) in zip(
ids or itertools.count(),
ids or uuid_generator(),
iter(vectors),
payload or itertools.cycle([{}]),
)
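End to end, the id change is visible when uploading without explicit ids: scrolling the collection now returns uuid strings instead of 0, 1, 2. A sketch against a local in-memory client (collection name and vector size are illustrative assumptions):

from qdrant_client import QdrantClient, models

client = QdrantClient(":memory:")
client.create_collection(
    collection_name="demo",
    vectors_config=models.VectorParams(size=2, distance=models.Distance.COSINE),
)
client.upload_collection(collection_name="demo", vectors=[[0.1, 0.2], [0.3, 0.4]])

records, _ = client.scroll(collection_name="demo", limit=2)
print([r.id for r in records])  # two uuid strings rather than [0, 1]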