Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix upsert check in local mode #432

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ def __init__(
)
self.payload: List[models.Payload] = []
self.deleted = np.zeros(0, dtype=bool)
all_vectors_keys = list(self.vectors.keys()) + list(self.sparse_vectors.keys())
self.deleted_per_vector = {name: np.zeros(0, dtype=bool) for name in all_vectors_keys}
self._all_vectors_keys = list(self.vectors.keys()) + list(self.sparse_vectors.keys())
self.deleted_per_vector = {
name: np.zeros(0, dtype=bool) for name in self._all_vectors_keys
}
self.ids: Dict[models.ExtendedPointId, int] = {} # Mapping from external id to internal id
self.ids_inv: List[models.ExtendedPointId] = [] # Mapping from internal id to external id
self.persistent = location is not None
Expand Down Expand Up @@ -1062,13 +1064,21 @@ def _upsert_point(self, point: models.PointStruct) -> None:
if isinstance(point.vector, dict):
updated_sparse_vectors = {}
for vector_name, vector in point.vector.items():
if vector_name not in self._all_vectors_keys:
raise ValueError(f"Wrong input: Not existing vector name error: {vector_name}")
if isinstance(vector, SparseVector):
# validate sparse vector
validate_sparse_vector(vector)
# sort sparse vector by indices before persistence
updated_sparse_vectors[vector_name] = sort(vector)
# update point.vector with the modified values after iteration
point.vector.update(updated_sparse_vectors)
else:
vector_names = list(self.vectors.keys())
if vector_names != [""]:
raise ValueError(
f"Wrong input: Unnamed vectors are not allowed when a collection has named vectors: {vector_names}"
)

if point.id in self.ids:
self._update_point(point)
Expand Down
4 changes: 4 additions & 0 deletions qdrant_client/local/qdrant_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,8 @@ def recreate_collection(
def upload_records(
self, collection_name: str, records: Iterable[types.Record], **kwargs: Any
) -> None:
# upload_records in local mode behaves like upload_records with wait=True in server mode

collection = self._get_collection(collection_name)
collection.upsert(
[
Expand All @@ -691,6 +693,8 @@ def upload_collection(
ids: Optional[Iterable[types.PointId]] = None,
**kwargs: Any,
) -> None:
# upload_collection in local mode behaves like upload_collection with wait=True in server mode

collection = self._get_collection(collection_name)
if isinstance(vectors, dict) and any(isinstance(v, np.ndarray) for v in vectors.values()):
assert (
Expand Down
74 changes: 74 additions & 0 deletions tests/congruence_tests/test_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,77 @@ def test_upload_payload_contain_nan_values():

local_client.delete_collection(nans_collection)
remote_client.delete_collection(nans_collection)


def test_upload_wrong_vectors():
local_client = init_local()
remote_client = init_remote()

vector_size = 2
wrong_vectors_collection = "test_collection"
vectors_config = {
"text": models.VectorParams(size=vector_size, distance=models.Distance.COSINE)
}
sparse_vectors_config = {"text-sparse": models.SparseVectorParams()}
local_client.recreate_collection(
collection_name=wrong_vectors_collection,
vectors_config=vectors_config,
sparse_vectors_config=sparse_vectors_config,
)
remote_client.recreate_collection(
collection_name=wrong_vectors_collection,
vectors_config=vectors_config,
sparse_vectors_config=sparse_vectors_config,
)

dense_vector = {"why_am_I_so_dense": [0.1, 0.3]}
dense_vectors = {"why_am_I_so_dense": [[0.1, 0.3]]}
sparse_vector = {"why_am_I_so_sparse": models.SparseVector(indices=[0, 1], values=[0.5, 0.6])}
sparse_vectors = {
"why_am_I_so_sparse": [models.SparseVector(indices=[0, 2], values=[0.3, 0.4])]
}

list_points = [models.PointStruct(id=1, vector=dense_vector)]
batch = models.Batch(ids=[2], vectors=dense_vectors)
list_points_sparse = [models.PointStruct(id=1, vector=sparse_vector)]
batch_sparse = models.Batch(ids=[2], vectors=sparse_vectors)

for points in (list_points, list_points_sparse, batch, batch_sparse):
with pytest.raises(qdrant_client.http.exceptions.UnexpectedResponse):
remote_client.upsert(wrong_vectors_collection, points)

with pytest.raises(ValueError):
local_client.upsert(wrong_vectors_collection, points)

for vector in (dense_vector, sparse_vector):
# does not raise without wait=True
with pytest.raises(qdrant_client.http.exceptions.UnexpectedResponse):
remote_client.upload_collection(wrong_vectors_collection, vectors=[vector], wait=True)

with pytest.raises(ValueError):
local_client.upload_collection(wrong_vectors_collection, vectors=[vector])

# does not raise without wait=True
with pytest.raises(qdrant_client.http.exceptions.UnexpectedResponse):
remote_client.upload_records(
wrong_vectors_collection,
records=[models.Record(id=3, vector=dense_vector)],
wait=True,
)

with pytest.raises(ValueError):
local_client.upload_records(
wrong_vectors_collection, records=[models.Record(id=3, vector=dense_vector)]
)

unnamed_vector = [0.1, 0.3]
with pytest.raises(qdrant_client.http.exceptions.UnexpectedResponse):
remote_client.upsert(
wrong_vectors_collection,
points=[models.PointStruct(id=1, vector=unnamed_vector)],
)
with pytest.raises(ValueError):
local_client.upsert(
wrong_vectors_collection,
points=[models.PointStruct(id=1, vector=unnamed_vector)],
)
27 changes: 7 additions & 20 deletions tests/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@ def test_dense_in_memory_key_filter_returns_results(qdrant: QdrantClient):
collection_name="test_collection",
wait=True,
points=[
models.PointStruct(
id=1,
vector=[0.05, 0.61, 0.76, 0.74],
payload={"city": "Berlin"}
),
models.PointStruct(id=1, vector=[0.05, 0.61, 0.76, 0.74], payload={"city": "Berlin"}),
models.PointStruct(
id=2,
vector=[0.19, 0.81, 0.75, 0.11],
Expand Down Expand Up @@ -73,18 +69,16 @@ def test_sparse_in_memory_key_filter_returns_results(qdrant: QdrantClient):
id=1,
vector={
"text": models.SparseVector(
indices=[0, 1, 2, 3],
values=[0.05, 0.61, 0.76, 0.74]
indices=[0, 1, 2, 3], values=[0.05, 0.61, 0.76, 0.74]
)
},
payload={"city": "Berlin"}
payload={"city": "Berlin"},
),
models.PointStruct(
id=2,
vector={
"text": models.SparseVector(
indices=[0, 1, 2, 3],
values=[0.19, 0.81, 0.75, 0.11]
indices=[0, 1, 2, 3], values=[0.19, 0.81, 0.75, 0.11]
)
},
payload={"city": ["Berlin", "London"]},
Expand All @@ -93,8 +87,7 @@ def test_sparse_in_memory_key_filter_returns_results(qdrant: QdrantClient):
id=3,
vector={
"text": models.SparseVector(
indices=[0, 1, 2, 3],
values=[0.36, 0.55, 0.47, 0.94]
indices=[0, 1, 2, 3], values=[0.36, 0.55, 0.47, 0.94]
)
},
payload={"city": ["Berlin", "Moscow"]},
Expand All @@ -103,14 +96,11 @@ def test_sparse_in_memory_key_filter_returns_results(qdrant: QdrantClient):
id=4,
vector={
"text": models.SparseVector(
indices=[0, 1, 2, 3],
values=[0.18, 0.01, 0.85, 0.80]
indices=[0, 1, 2, 3], values=[0.18, 0.01, 0.85, 0.80]
)
},
payload={"city": ["London", "Moscow"]},
),
models.PointStruct(id=5, vector=[0.24, 0.18, 0.22, 0.44], payload={"count": [0]}),
models.PointStruct(id=6, vector=[0.35, 0.08, 0.11, 0.44]),
],
)

Expand All @@ -121,10 +111,7 @@ def test_sparse_in_memory_key_filter_returns_results(qdrant: QdrantClient):
collection_name="test_collection",
query_vector=models.NamedSparseVector(
name="text",
vector=models.SparseVector(
indices=[0, 1, 2, 3],
values=[0.2, 0.1, 0.9, 0.7]
)
vector=models.SparseVector(indices=[0, 1, 2, 3], values=[0.2, 0.1, 0.9, 0.7]),
),
query_filter=models.Filter(
must=[models.FieldCondition(key="city", match=models.MatchValue(value="London"))]
Expand Down
Loading