Skip to content

Commit

Permalink
fix: fix QueryVector isinstance check in local mode (#562)
Browse files Browse the repository at this point in the history
* fix: fix QueryVector isinstance check in local mode

* fix: update query vector type, fix invalid vector type in search in local mode

* fix: remove get_args_subscribed type hint, restore QueryVector, remove redundant array check
  • Loading branch information
joein authored Apr 2, 2024
1 parent 483578a commit a49b6f6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 7 deletions.
11 changes: 9 additions & 2 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,15 @@ def remap_type(tp: type) -> type:
return typing_remap.get(tp, tp)


def get_args_subscribed(tp: type): # type: ignore
"""Get type arguments with all substitutions performed. Supports subscripted generics having __origin__"""
def get_args_subscribed(tp): # type: ignore
"""Get type arguments with all substitutions performed. Supports subscripted generics having __origin__
Args:
tp: type to get arguments from. Can be either a type or a subscripted generic
Returns:
tuple of type arguments
"""
return tuple(
remap_type(arg if not hasattr(arg, "__origin__") else arg.__origin__)
for arg in get_args(tp)
Expand Down
7 changes: 2 additions & 5 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from qdrant_client import grpc as grpc
from qdrant_client._pydantic_compat import construct
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.common_types import Vector
from qdrant_client.conversions.common_types import get_args_subscribed
from qdrant_client.conversions.conversion import GrpcToRest
from qdrant_client.http import models
from qdrant_client.http.models.models import Distance, ExtendedPointId, SparseVector
Expand Down Expand Up @@ -202,10 +202,7 @@ def _resolve_query_vector_name(
elif isinstance(query_vector, list):
name = DEFAULT_VECTOR_NAME
vector = np.array(query_vector)
elif isinstance(query_vector, np.ndarray):
name = DEFAULT_VECTOR_NAME
vector = query_vector
elif isinstance(query_vector, get_args(QueryVector)):
elif isinstance(query_vector, get_args_subscribed(QueryVector)):
name = DEFAULT_VECTOR_NAME
vector = query_vector
else:
Expand Down
18 changes: 18 additions & 0 deletions tests/congruence_tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

import numpy as np
import pytest

from qdrant_client.client_base import QdrantBase
from qdrant_client.http.models import models
Expand Down Expand Up @@ -318,3 +319,20 @@ def test_search_with_persistence_and_skipped_vectors():
except AssertionError as e:
print(f"\nFailed with filter {query_filter}")
raise e


def test_search_invalid_vector_type():
fixture_points = generate_fixtures()

local_client = init_local()
init_client(local_client, fixture_points)

remote_client = init_remote()
init_client(remote_client, fixture_points)

vector_invalid_type = {"text": [1, 2, 3, 4]}
with pytest.raises(ValueError):
local_client.search(collection_name=COLLECTION_NAME, query_vector=vector_invalid_type)

with pytest.raises(ValueError):
remote_client.search(collection_name=COLLECTION_NAME, query_vector=vector_invalid_type)

0 comments on commit a49b6f6

Please sign in to comment.