Skip to content

Commit

Permalink
fix: remove get_args_subscribed type hint, restore QueryVector, remov…
Browse files Browse the repository at this point in the history
…e redundant array check
  • Loading branch information
joein committed Apr 2, 2024
1 parent 3fcc896 commit a62bffa
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
11 changes: 9 additions & 2 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,15 @@ def remap_type(tp: type) -> type:
return typing_remap.get(tp, tp)


def get_args_subscribed(tp: type): # type: ignore
"""Get type arguments with all substitutions performed. Supports subscripted generics having __origin__"""
def get_args_subscribed(tp): # type: ignore
"""Get type arguments with all substitutions performed. Supports subscripted generics having __origin__
Args:
tp: type to get arguments from. Can be either a type or a subscripted generic
Returns:
tuple of type arguments
"""
return tuple(
remap_type(arg if not hasattr(arg, "__origin__") else arg.__origin__)
for arg in get_args(tp)
Expand Down
2 changes: 1 addition & 1 deletion qdrant_client/local/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, context_pairs: List[ContextPair]):
self.context_pairs = context_pairs


QueryVector = Union[DiscoveryQuery, ContextQuery, RecoQuery, SparseVector]
QueryVector = Union[DiscoveryQuery, ContextQuery, RecoQuery, types.NumpyArray, SparseVector]


class DistanceOrder(str, Enum):
Expand Down
10 changes: 4 additions & 6 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from qdrant_client import grpc as grpc
from qdrant_client._pydantic_compat import construct
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.common_types import get_args_subscribed
from qdrant_client.conversions.conversion import GrpcToRest
from qdrant_client.http import models
from qdrant_client.http.models.models import Distance, ExtendedPointId, SparseVector
Expand Down Expand Up @@ -184,8 +185,8 @@ def _resolve_query_vector_name(
QueryVector,
Tuple[str, QueryVector],
],
) -> Tuple[str, Union[types.NumpyArray, QueryVector]]:
vector: Union[QueryVector, types.NumpyArray]
) -> Tuple[str, QueryVector]:
vector: QueryVector
if isinstance(query_vector, tuple):
name, query = query_vector
if isinstance(query, list):
Expand All @@ -201,10 +202,7 @@ def _resolve_query_vector_name(
elif isinstance(query_vector, list):
name = DEFAULT_VECTOR_NAME
vector = np.array(query_vector)
elif isinstance(query_vector, np.ndarray):
name = DEFAULT_VECTOR_NAME
vector = query_vector
elif isinstance(query_vector, get_args(QueryVector)):
elif isinstance(query_vector, get_args_subscribed(QueryVector)):
name = DEFAULT_VECTOR_NAME
vector = query_vector
else:
Expand Down

0 comments on commit a62bffa

Please sign in to comment.