Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Local inference image support #836

Merged
merged 10 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 105 additions & 10 deletions qdrant_client/async_qdrant_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#
# ****** WARNING: THIS FILE IS AUTOGENERATED ******

import base64
import io
import uuid
import warnings
from itertools import tee
Expand All @@ -19,6 +21,7 @@
from qdrant_client.async_client_base import AsyncQdrantBase
from qdrant_client.conversions import common_types as types
from qdrant_client.conversions.conversion import GrpcToRest
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.embed_inspector import InspectorEmbed
from qdrant_client.embed.models import NumericVector, NumericVectorStruct
from qdrant_client.embed.schema_parser import ModelSchemaParser
Expand All @@ -29,13 +32,21 @@
from qdrant_client import grpc

try:
from fastembed import SparseTextEmbedding, TextEmbedding, LateInteractionTextEmbedding
from fastembed import (
SparseTextEmbedding,
TextEmbedding,
LateInteractionTextEmbedding,
ImageEmbedding,
)
from fastembed.common import OnnxProvider
from PIL import Image as PilImage
except ImportError:
TextEmbedding = None
SparseTextEmbedding = None
OnnxProvider = None
LateInteractionTextEmbedding = None
ImageEmbedding = None
PilImage = None
SUPPORTED_EMBEDDING_MODELS: Dict[str, Tuple[int, models.Distance]] = (
{
model["model"]: (model["dim"], models.Distance.COSINE)
Expand Down Expand Up @@ -63,13 +74,19 @@
if LateInteractionTextEmbedding
else {}
)
_IMAGE_EMBEDDING_MODELS: Dict[str, Tuple[int, models.Distance]] = (
{model["model"]: model for model in ImageEmbedding.list_supported_models()}
if ImageEmbedding
else {}
)


class AsyncQdrantFastembedMixin(AsyncQdrantBase):
DEFAULT_EMBEDDING_MODEL = "BAAI/bge-small-en"
embedding_models: Dict[str, "TextEmbedding"] = {}
sparse_embedding_models: Dict[str, "SparseTextEmbedding"] = {}
late_interaction_embedding_models: Dict[str, "LateInteractionTextEmbedding"] = {}
image_embedding_models: Dict[str, "ImageEmbedding"] = {}
_FASTEMBED_INSTALLED: bool

def __init__(self, parser: ModelSchemaParser, **kwargs: Any):
Expand Down Expand Up @@ -294,6 +311,31 @@ def _get_or_init_late_interaction_model(
)
return cls.late_interaction_embedding_models[model_name]

@classmethod
def _get_or_init_image_model(
cls,
model_name: str,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
providers: Optional[Sequence["OnnxProvider"]] = None,
**kwargs: Any,
) -> "ImageEmbedding":
if model_name in cls.image_embedding_models:
return cls.image_embedding_models[model_name]
cls._import_fastembed()
if model_name not in _IMAGE_EMBEDDING_MODELS:
raise ValueError(
f"Unsupported embedding model: {model_name}. Supported models: {_IMAGE_EMBEDDING_MODELS}"
)
cls.image_embedding_models[model_name] = ImageEmbedding(
model_name=model_name,
cache_dir=cache_dir,
threads=threads,
providers=providers,
**kwargs,
)
return cls.image_embedding_models[model_name]

def _embed_documents(
self,
documents: Iterable[str],
Expand Down Expand Up @@ -727,8 +769,9 @@ async def query_batch(
]
return [self._scored_points_to_query_responses(response) for response in responses]

@staticmethod
@classmethod
def _resolve_query(
cls,
query: Union[
types.PointId,
List[float],
Expand Down Expand Up @@ -764,10 +807,10 @@ def _resolve_query(
GrpcToRest.convert_point_id(query) if isinstance(query, grpc.PointId) else query
)
return models.NearestQuery(nearest=query)
if isinstance(query, models.Document):
if isinstance(query, INFERENCE_OBJECT_TYPES):
model_name = query.model
if model_name is None:
raise ValueError("`model` field has to be set explicitly in the `Document`")
raise ValueError(f"`model` field has to be set explicitly in the {type(query)}")
return models.NearestQuery(nearest=query)
if query is None:
return None
Expand Down Expand Up @@ -813,7 +856,7 @@ def _embed_models(
A deepcopy of the method with embedded fields
"""
if paths is None:
if isinstance(model, models.Document):
if isinstance(model, INFERENCE_OBJECT_TYPES):
return self._embed_raw_data(model, is_query=is_query)
model = deepcopy(model)
paths = self._embed_inspector.inspect(model)
Expand All @@ -839,6 +882,32 @@ def _embed_models(
setattr(item, path.current, embeddings[0])
return model

@staticmethod
def _resolve_inference_object(data: models.VectorStruct) -> models.VectorStruct:
"""Resolve inference object into a model

Args:
data: models.VectorStruct - data to resolve, if it's an inference object, convert it to a proper type,
otherwise - keep unchanged

Returns:
models.VectorStruct: resolved data
"""
if not isinstance(data, models.InferenceObject):
return data
model_name = data.model
value = data.object
options = data.options
if model_name in (
*SUPPORTED_EMBEDDING_MODELS.keys(),
*SUPPORTED_SPARSE_EMBEDDING_MODELS.keys(),
*_LATE_INTERACTION_EMBEDDING_MODELS.keys(),
):
return models.Document(model=model_name, text=value, options=options)
if model_name in _IMAGE_EMBEDDING_MODELS:
return models.Image(model=model_name, image=value, options=options)
raise ValueError(f"{model_name} is not among supported models")

def _embed_raw_data(
self, data: models.VectorStruct, is_query: bool = False
) -> NumericVectorStruct:
Expand All @@ -851,8 +920,11 @@ def _embed_raw_data(
Returns:
NumericVectorStruct: Embedded data
"""
data = self._resolve_inference_object(data)
if isinstance(data, models.Document):
return self._embed_document(data, is_query=is_query)
elif isinstance(data, models.Image):
return self._embed_image(data)
elif isinstance(data, dict):
return {
key: self._embed_raw_data(value, is_query=is_query)
Expand All @@ -879,18 +951,17 @@ def _embed_document(self, document: models.Document, is_query: bool = False) ->
"""
model_name = document.model
text = document.text
options = document.options or {}
if model_name in SUPPORTED_EMBEDDING_MODELS:
embedding_model_inst = self._get_or_init_model(
model_name=model_name, **document.options or {}
)
embedding_model_inst = self._get_or_init_model(model_name=model_name, **options)
if not is_query:
embedding = list(embedding_model_inst.embed(documents=[text]))[0].tolist()
else:
embedding = list(embedding_model_inst.query_embed(query=text))[0].tolist()
return embedding
elif model_name in SUPPORTED_SPARSE_EMBEDDING_MODELS:
sparse_embedding_model_inst = self._get_or_init_sparse_model(
model_name=model_name, **document.options or {}
model_name=model_name, **options
)
if not is_query:
sparse_embedding = list(sparse_embedding_model_inst.embed(documents=[text]))[0]
Expand All @@ -901,7 +972,7 @@ def _embed_document(self, document: models.Document, is_query: bool = False) ->
)
elif model_name in _LATE_INTERACTION_EMBEDDING_MODELS:
li_embedding_model_inst = self._get_or_init_late_interaction_model(
model_name=model_name, **document.options or {}
model_name=model_name, **options
)
if not is_query:
embedding = list(li_embedding_model_inst.embed(documents=[text]))[0].tolist()
Expand All @@ -910,3 +981,27 @@ def _embed_document(self, document: models.Document, is_query: bool = False) ->
return embedding
else:
raise ValueError(f"{model_name} is not among supported models")

def _embed_image(self, image: models.Image) -> NumericVector:
"""Embed an image using the specified embedding model

Args:
image: Image to embed

Returns:
NumericVector: Image's embedding

Raises:
ValueError: If model is not supported
"""
model_name = image.model
if model_name in _IMAGE_EMBEDDING_MODELS:
embedding_model_inst = self._get_or_init_image_model(
model_name=model_name, **image.options or {}
)
image_data = base64.b64decode(image.image)
with io.BytesIO(image_data) as buffer:
with PilImage.open(buffer) as image:
embedding = list(embedding_model_inst.embed(images=[image]))[0].tolist()
return embedding
raise ValueError(f"{model_name} is not among supported models")
8 changes: 8 additions & 0 deletions qdrant_client/embed/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Set, Type, Tuple

from qdrant_client.http import models

INFERENCE_OBJECT_NAMES: Set[str] = {"Document", "Image", "InferenceObject"}
INFERENCE_OBJECT_TYPES: Tuple[
Type[models.Document], Type[models.Image], Type[models.InferenceObject]
] = (models.Document, models.Image, models.InferenceObject)
7 changes: 4 additions & 3 deletions qdrant_client/embed/embed_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import BaseModel

from qdrant_client._pydantic_compat import model_fields_set
from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser

from qdrant_client.embed.utils import convert_paths, Path
Expand Down Expand Up @@ -112,7 +113,7 @@ def inspect_recursive(member: BaseModel, accumulator: str) -> List[str]:
if model is None:
return []

if isinstance(model, models.Document):
if isinstance(model, INFERENCE_OBJECT_TYPES):
return [accum]

if isinstance(model, BaseModel):
Expand All @@ -132,7 +133,7 @@ def inspect_recursive(member: BaseModel, accumulator: str) -> List[str]:
if not isinstance(current_model, BaseModel):
continue

if isinstance(current_model, models.Document):
if isinstance(current_model, INFERENCE_OBJECT_TYPES):
found_paths.append(accum)

found_paths.extend(inspect_recursive(current_model, accum))
Expand All @@ -157,7 +158,7 @@ def inspect_recursive(member: BaseModel, accumulator: str) -> List[str]:
if not isinstance(current_model, BaseModel):
continue

if isinstance(current_model, models.Document):
if isinstance(current_model, INFERENCE_OBJECT_TYPES):
found_paths.append(accum)

found_paths.extend(inspect_recursive(current_model, accum))
Expand Down
6 changes: 2 additions & 4 deletions qdrant_client/embed/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from pydantic import StrictFloat, StrictStr

from qdrant_client.grpc import SparseVector
from qdrant_client.http.models import ExtendedPointId
from qdrant_client.models import Document # type: ignore[attr-defined]
from qdrant_client.http.models import ExtendedPointId, SparseVector


NumericVector = Union[
Expand All @@ -24,4 +22,4 @@
Dict[StrictStr, NumericVector],
]

__all__ = ["Document", "NumericVector", "NumericVectorInput", "NumericVectorStruct"]
__all__ = ["NumericVector", "NumericVectorInput", "NumericVectorStruct"]
3 changes: 2 additions & 1 deletion qdrant_client/embed/schema_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class ModelSchemaParser:
"""

CACHE_PATH = "_inspection_cache.py"
INFERENCE_OBJECT_NAMES = {"Document", "Image"}

def __init__(self) -> None:
self._defs: Dict[str, Union[Dict[str, Any], List[Dict[str, Any]]]] = deepcopy(DEFS) # type: ignore[arg-type]
Expand Down Expand Up @@ -159,7 +160,7 @@ def _find_document_paths(
if not isinstance(schema, dict):
return document_paths

if "title" in schema and schema["title"] == "Document":
if "title" in schema and schema["title"] in self.INFERENCE_OBJECT_NAMES:
document_paths.append(current_path)
return document_paths

Expand Down
9 changes: 5 additions & 4 deletions qdrant_client/embed/type_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic import BaseModel

from qdrant_client.embed.common import INFERENCE_OBJECT_TYPES
from qdrant_client.embed.schema_parser import ModelSchemaParser
from qdrant_client.embed.utils import Path
from qdrant_client.http import models
Expand Down Expand Up @@ -41,7 +42,7 @@ def inspect(self, points: Union[Iterable[BaseModel], BaseModel]) -> bool:
return False

def _inspect_model(self, model: BaseModel, paths: Optional[List[Path]] = None) -> bool:
if isinstance(model, models.Document):
if isinstance(model, INFERENCE_OBJECT_TYPES):
return True

paths = (
Expand Down Expand Up @@ -80,7 +81,7 @@ def inspect_recursive(member: BaseModel) -> bool:
if model is None:
return False

if isinstance(model, models.Document):
if isinstance(model, INFERENCE_OBJECT_TYPES):
return True

if isinstance(model, BaseModel):
Expand All @@ -98,7 +99,7 @@ def inspect_recursive(member: BaseModel) -> bool:

elif isinstance(model, list):
for current_model in model:
if isinstance(current_model, models.Document):
if isinstance(current_model, INFERENCE_OBJECT_TYPES):
return True

if not isinstance(current_model, BaseModel):
Expand All @@ -121,7 +122,7 @@ def inspect_recursive(member: BaseModel) -> bool:
for key, values in model.items():
values = [values] if not isinstance(values, list) else values
for current_model in values:
if isinstance(current_model, models.Document):
if isinstance(current_model, INFERENCE_OBJECT_TYPES):
return True

if not isinstance(current_model, BaseModel):
Expand Down
1 change: 1 addition & 0 deletions qdrant_client/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from qdrant_client.http.models import *
from qdrant_client.fastembed_common import *
from qdrant_client.embed.models import *
Loading
Loading