
new: deprecate upload records, update tests, prohibit migration of co… #447

Merged · 7 commits · Jan 19, 2024
5 changes: 5 additions & 0 deletions qdrant_client/async_client_base.py
@@ -277,6 +277,11 @@ def upload_records(
) -> None:
raise NotImplementedError()

async def upload_points(
self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
) -> None:
raise NotImplementedError()

def upload_collection(
self,
collection_name: str,
55 changes: 53 additions & 2 deletions qdrant_client/async_qdrant_client.py
@@ -9,6 +9,7 @@
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******

import warnings
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

from qdrant_client import grpc as grpc
@@ -1741,6 +1742,11 @@ def upload_records(
This parameter overwrites shard keys written in the records.

"""
warnings.warn(
"`upload_records` is deprecated, use `upload_points` instead",
DeprecationWarning,
stacklevel=2,
)
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
return self._client.upload_records(
collection_name=collection_name,
@@ -1751,7 +1757,53 @@ def upload_records(
max_retries=max_retries,
wait=wait,
shard_key_selector=shard_key_selector,
**kwargs,
)

async def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
batch_size: int = 64,
parallel: int = 1,
method: Optional[str] = None,
max_retries: int = 3,
wait: bool = False,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> None:
"""Upload points to the collection

Similar to the `upload_collection` method, but operates on points rather than on vectors and payloads separately.

Args:
collection_name: Name of the collection to upload to
points: Iterator over points to upload
batch_size: How many points to upload per request. Default: 64
parallel: Number of parallel processes used for the upload
method: Start method for parallel processes. Default: forkserver
max_retries: maximum number of retries in case of a failure
during the upload of a batch
wait:
Wait for the results to be applied on the server side.
If `true`, each update request will explicitly wait for the confirmation of completion. Might be slower.
If `false`, each update request will return immediately after the server confirms receipt.
Default: `false`
shard_key_selector: Defines the shard groups that should be used to write updates into.
If multiple shard_keys are provided, the update will be written to each of them.
Only works for collections with the `custom` sharding method.
This parameter overrides shard keys written in the points.

"""
assert len(kwargs) == 0, f"Unknown arguments: {list(kwargs.keys())}"
return await self._client.upload_points(
collection_name=collection_name,
points=points,
batch_size=batch_size,
parallel=parallel,
method=method,
max_retries=max_retries,
wait=wait,
shard_key_selector=shard_key_selector,
)
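
Taken together, the deprecation and the new method give the following migration path. A minimal sketch (not part of the diff), assuming a Qdrant server on localhost:6333 and an existing collection "demo" with 2-dimensional vectors:

import asyncio
import warnings

from qdrant_client import AsyncQdrantClient, models

async def main() -> None:
    client = AsyncQdrantClient(url="http://localhost:6333")
    points = [
        models.PointStruct(id=1, vector=[0.1, 0.2], payload={"doc": "a"}),
        models.PointStruct(id=2, vector=[0.3, 0.4], payload={"doc": "b"}),
    ]
    # New API introduced in this PR: uploads PointStruct objects directly.
    await client.upload_points(collection_name="demo", points=points, wait=True)
    # The old API still works (note: a plain def on the async client, so no
    # await), but now emits a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        client.upload_records(
            collection_name="demo",
            records=[models.Record(id=3, vector=[0.5, 0.6], payload={})],
        )
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

asyncio.run(main())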

def upload_collection(
@@ -1806,7 +1858,6 @@ def upload_collection(
max_retries=max_retries,
wait=wait,
shard_key_selector=shard_key_selector,
**kwargs,
)

async def create_payload_index(
14 changes: 6 additions & 8 deletions qdrant_client/async_qdrant_fastembed.py
@@ -13,8 +13,6 @@
from itertools import tee
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from pydantic import BaseModel

from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.conversions import common_types as types
from qdrant_client.fastembed_common import QueryResponse
@@ -172,13 +170,13 @@ def _scored_points_to_query_responses(
)
return response

def _records_iterator(
def _points_iterator(
self,
ids: Optional[Iterable[models.ExtendedPointId]],
metadata: Optional[Iterable[Dict[str, Any]]],
encoded_docs: Iterable[Tuple[str, List[float]]],
ids_accumulator: list,
) -> Iterable[models.Record]:
) -> Iterable[models.PointStruct]:
if ids is None:
ids = iter(lambda: uuid.uuid4().hex, None)
if metadata is None:
@@ -187,7 +185,7 @@ def _records_iterator(
for idx, meta, (doc, vector) in zip(ids, metadata, encoded_docs):
ids_accumulator.append(idx)
payload = {"document": doc, **meta}
yield models.Record(id=idx, payload=payload, vector={vector_name: vector})
yield models.PointStruct(id=idx, payload=payload, vector={vector_name: vector})
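
One subtlety carried over from `_records_iterator`: the two-argument form of `iter` used for default ids calls the lambda until it returns the sentinel (`None`), which `uuid.uuid4().hex` never does, so it yields an endless stream of fresh ids that `zip` truncates to the length of `encoded_docs`. A standalone illustration:

import uuid
from itertools import islice

# iter(callable, sentinel): call the lambda until it returns None -- it
# never does, so this is an infinite stream of 32-char hex ids.
ids = iter(lambda: uuid.uuid4().hex, None)
print(list(islice(ids, 3)))  # three distinct ids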

def get_fastembed_vector_params(
self,
@@ -289,12 +287,12 @@ async def add(
distance == vector_params.distance
), f"Distance mismatch: {distance} != {vector_params.distance}"
inserted_ids: list = []
records = self._records_iterator(
points = self._points_iterator(
ids=ids, metadata=metadata, encoded_docs=encoded_docs, ids_accumulator=inserted_ids
)
self.upload_records(
await self.upload_points(
collection_name=collection_name,
records=records,
points=points,
wait=True,
parallel=parallel or 1,
batch_size=batch_size,
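With this change, the async fastembed path embeds documents and hands the resulting `PointStruct` stream to the awaited `upload_points`. A hedged sketch of the caller side, assuming the `fastembed` extra is installed and that `add` creates the collection when it is missing, as recent client versions do:

import asyncio

from qdrant_client import AsyncQdrantClient

async def main() -> None:
    client = AsyncQdrantClient(":memory:")
    # add() embeds the documents with fastembed (first run downloads the
    # model) and uploads them as points via upload_points.
    ids = await client.add(
        collection_name="demo_docs",
        documents=["qdrant is a vector database", "points replace records"],
    )
    print(ids)  # hex UUIDs accumulated by _points_iterator

asyncio.run(main())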
25 changes: 25 additions & 0 deletions qdrant_client/async_qdrant_remote.py
@@ -2054,6 +2054,31 @@ def upload_records(
wait=wait,
)

async def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
batch_size: int = 64,
parallel: int = 1,
method: Optional[str] = None,
max_retries: int = 3,
wait: bool = False,
shard_key_selector: Optional[types.ShardKeySelector] = None,
**kwargs: Any,
) -> None:
batches_iterator = self._updater_class.iterate_records_batches(
records=points, batch_size=batch_size
)
self._upload_collection(
batches_iterator=batches_iterator,
collection_name=collection_name,
max_retries=max_retries,
parallel=parallel,
method=method,
wait=wait,
shard_key_selector=shard_key_selector,
)
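
Note that the remote implementation feeds `points` straight into the existing records batcher, `iterate_records_batches(records=points, ...)`; presumably this works because `PointStruct` carries the same `id`/`vector`/`payload` fields the batcher reads from `Record`, so no separate upload path is needed.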

def upload_collection(
self,
collection_name: str,
8 changes: 8 additions & 0 deletions qdrant_client/client_base.py
@@ -296,6 +296,14 @@ def upload_records(
) -> None:
raise NotImplementedError()

def upload_points(
self,
collection_name: str,
points: Iterable[types.PointStruct],
**kwargs: Any,
) -> None:
raise NotImplementedError()

def upload_collection(
self,
collection_name: str,
3 changes: 2 additions & 1 deletion qdrant_client/conversions/common_types.py
@@ -8,7 +8,7 @@
else:
from typing_extensions import TypeAlias

from typing import List, Tuple, Union, get_args
from typing import List, Union, get_args

from qdrant_client import grpc as grpc
from qdrant_client.http import models as rest
@@ -53,6 +53,7 @@ def get_args_subscribed(tp: type): # type: ignore
PayloadSchemaType = Union[
rest.PayloadSchemaType, rest.PayloadSchemaParams, int, grpc.PayloadIndexParams
] # type(grpc.PayloadSchemaType) == int
PointStruct: TypeAlias = rest.PointStruct
Points = Union[rest.Batch, List[Union[rest.PointStruct, grpc.PointStruct]]]
PointsSelector = Union[
List[PointId], rest.Filter, grpc.Filter, rest.PointsSelector, grpc.PointsSelector
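With the alias in place, `types.PointStruct` resolves to the REST model, so points can be constructed via `qdrant_client.models` as usual. A minimal sketch (nothing assumed beyond the client package):

from qdrant_client import models

# types.PointStruct is now an alias for this REST model.
point = models.PointStruct(
    id="550e8400-e29b-41d4-a716-446655440000",  # an int or a UUID string
    vector=[0.1, 0.2, 0.3],
    payload={"document": "hello"},
)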
34 changes: 30 additions & 4 deletions qdrant_client/local/async_qdrant_local.py
@@ -15,7 +15,19 @@
import os
import shutil
from io import TextIOWrapper
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from uuid import uuid4

import numpy as np
import portalocker
@@ -643,16 +655,26 @@ async def recreate_collection(
collection_name, vectors_config, init_from, sparse_vectors_config
)

async def upload_points(
self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
) -> None:
self._upload_points(collection_name, points)

def upload_records(
self, collection_name: str, records: Iterable[types.Record], **kwargs: Any
) -> None:
self._upload_points(collection_name, records)

def _upload_points(
self, collection_name: str, points: Iterable[Union[types.PointStruct, types.Record]]
) -> None:
collection = self._get_collection(collection_name)
collection.upsert(
[
rest_models.PointStruct(
id=record.id, vector=record.vector or {}, payload=record.payload or {}
id=point.id, vector=point.vector or {}, payload=point.payload or {}
)
for record in records
for point in points
]
)

@@ -666,6 +688,10 @@ def upload_collection(
ids: Optional[Iterable[types.PointId]] = None,
**kwargs: Any,
) -> None:
def uuid_generator() -> Generator[str, None, None]:
while True:
yield str(uuid4())

collection = self._get_collection(collection_name)
if isinstance(vectors, dict) and any(
(isinstance(v, np.ndarray) for v in vectors.values())
@@ -686,7 +712,7 @@ def upload_collection(
payload=payload or {},
)
for (point_id, vector, payload) in zip(
ids or itertools.count(), iter(vectors), payload or itertools.cycle([{}])
ids or uuid_generator(), iter(vectors), payload or itertools.cycle([{}])
)
]
)
43 changes: 36 additions & 7 deletions qdrant_client/local/qdrant_local.py
@@ -4,7 +4,19 @@
import os
import shutil
from io import TextIOWrapper
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
from uuid import uuid4

import numpy as np
import portalocker
@@ -643,7 +655,10 @@ def create_collection(
if src_collection and from_collection_name:
batch_size = 100
records, next_offset = self.scroll(from_collection_name, limit=2, with_vectors=True)
self.upload_records(collection_name, records)
self.upload_records(
collection_name, records
) # it is not crucial to replace upload_records here
# since it is an internal usage, and we don't have custom shard keys in qdrant local
while next_offset is not None:
records, next_offset = self.scroll(
from_collection_name, offset=next_offset, limit=batch_size, with_vectors=True
@@ -666,20 +681,31 @@ def recreate_collection(
collection_name, vectors_config, init_from, sparse_vectors_config
)

def upload_points(
self, collection_name: str, points: Iterable[types.PointStruct], **kwargs: Any
) -> None:
self._upload_points(collection_name, points)

def upload_records(
self, collection_name: str, records: Iterable[types.Record], **kwargs: Any
) -> None:
# upload_records in local mode behaves like upload_records with wait=True in server mode
self._upload_points(collection_name, records)

def _upload_points(
self,
collection_name: str,
points: Iterable[Union[types.PointStruct, types.Record]],
) -> None:
collection = self._get_collection(collection_name)
collection.upsert(
[
rest_models.PointStruct(
id=record.id,
vector=record.vector or {},
payload=record.payload or {},
id=point.id,
vector=point.vector or {},
payload=point.payload or {},
)
for record in records
for point in points
]
)

@@ -694,6 +720,9 @@ def upload_collection(
**kwargs: Any,
) -> None:
# upload_collection in local mode behaves like upload_collection with wait=True in server mode
def uuid_generator() -> Generator[str, None, None]:
while True:
yield str(uuid4())

collection = self._get_collection(collection_name)
if isinstance(vectors, dict) and any(isinstance(v, np.ndarray) for v in vectors.values()):
@@ -716,7 +745,7 @@ def upload_collection(
payload=payload or {},
)
for (point_id, vector, payload) in zip(
ids or itertools.count(),
ids or uuid_generator(),
iter(vectors),
payload or itertools.cycle([{}]),
)
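A behavioral change worth noting in both local implementations: when `ids` is omitted, `upload_collection` now draws ids from random UUID strings instead of the sequential integers `itertools.count()` produced. A standalone sketch of the generator:

from typing import Generator
from uuid import uuid4

def uuid_generator() -> Generator[str, None, None]:
    # Endless stream of random UUID strings; zip() truncates it to the
    # number of vectors being uploaded.
    while True:
        yield str(uuid4())

gen = uuid_generator()
print(next(gen), next(gen))  # two distinct UUID strings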