Skip to content

Commit

Permalink
WIP: hybrid search with fastembed (#553)
Browse files Browse the repository at this point in the history
* WIP: hybrid search with fastembed

* hybrid queries with fastembed

* test for hybrid

* fix typo

* new: extend hybrid search tests, fix mypy, small refactoring (#554)

* refactor: align model name parameters in setters, update tests

* fix: fix async

* fix: add a good test, fix sparse vectors in query batch

* refactoring: reduce branching, refactor fastembed tests

---------

Co-authored-by: George <[email protected]>
  • Loading branch information
generall and joein committed Mar 27, 2024
1 parent 5de0f3d commit 0d7e46c
Show file tree
Hide file tree
Showing 10 changed files with 783 additions and 266 deletions.
198 changes: 102 additions & 96 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ grpcio-tools = ">=1.41.0"
urllib3 = ">=1.26.14,<3"
portalocker = "^2.7.0"
fastembed = [
{ version = "0.2.2", optional = true, python = "<3.13" }
{ version = "0.2.5", optional = true, python = "<3.13" }
]

[tool.poetry.group.dev.dependencies]
Expand Down
288 changes: 247 additions & 41 deletions qdrant_client/async_qdrant_fastembed.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions qdrant_client/conversions/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def get_args_subscribed(tp: type): # type: ignore
# Each alias below binds a name in this module to the same-named class on
# `rest`, so callers can reference the type without going through `rest`.
SnapshotDescription: TypeAlias = rest.SnapshotDescription
NamedVector: TypeAlias = rest.NamedVector
NamedSparseVector: TypeAlias = rest.NamedSparseVector
SparseVector: TypeAlias = rest.SparseVector
PointVectors: TypeAlias = rest.PointVectors
Vector: TypeAlias = rest.Vector
VectorStruct: TypeAlias = rest.VectorStruct
Expand Down
5 changes: 4 additions & 1 deletion qdrant_client/fastembed_common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel
from pydantic import BaseModel, Field

from qdrant_client.conversions.common_types import SparseVector


# Pydantic model describing one query result.
#
# `extra="forbid"` is passed to pydantic, so constructing the model with a
# field that is not declared below is rejected instead of silently kept.
# NOTE(review): `# type: ignore` presumably silences mypy on the class-keyword
# config syntax -- confirm before removing.
class QueryResponse(BaseModel, extra="forbid"): # type: ignore
    # Identifier of the result; either a string or an integer.
    id: Union[str, int]
    # Optional list of floats (presumably the dense vector -- verify against
    # the code that builds this model). No default is declared here.
    embedding: Optional[List[float]]
    # Optional sparse vector; explicitly defaults to None when not supplied.
    sparse_embedding: Optional[SparseVector] = Field(default=None)
    # Free-form key/value data attached to the result.
    metadata: Dict[str, Any]
    # Text of the document.
    document: str
    # Score assigned to this result.
    score: float
Empty file.
31 changes: 31 additions & 0 deletions qdrant_client/hybrid/fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Dict, List

from qdrant_client.http import models


def reciprocal_rank_fusion(
    responses: List[List[models.ScoredPoint]],
    limit: int = 10,
    ranking_constant: int = 2,
) -> List[models.ScoredPoint]:
    """Merge several ranked result lists with Reciprocal Rank Fusion.

    For every list a point appears in, it receives
    ``1 / (ranking_constant + position)`` where ``position`` is the point's
    zero-based index in that list. Contributions are summed per point id and
    the points with the highest totals are returned. The scores carried by
    the input points are ignored; only their positions matter.

    Args:
        responses: Ranked lists of scored points, best match first.
        limit: Maximum number of fused points to return.
        ranking_constant: Added to the position before taking the reciprocal;
            the constant mitigates the impact of high rankings by outlier
            systems. Must be positive. Defaults to 2.

    Returns:
        At most ``limit`` points ordered by descending fused score. Points
        with equal fused scores keep the order in which they were first seen.
        Each returned point is a copy of the first occurrence of that id with
        ``score`` replaced by the fused score; the objects inside
        ``responses`` are left unmodified.

    Raises:
        ValueError: If ``ranking_constant`` is not positive (position 0 would
            otherwise divide by zero or yield a negative score).
    """
    if ranking_constant <= 0:
        raise ValueError("ranking_constant must be positive")

    scores: Dict[models.ExtendedPointId, float] = {}
    first_seen: Dict[models.ExtendedPointId, models.ScoredPoint] = {}
    for response in responses:
        for position, scored_point in enumerate(response):
            point_id = scored_point.id
            # Remember the first object seen for an id; later duplicates only
            # add to the accumulated score.
            first_seen.setdefault(point_id, scored_point)
            scores[point_id] = scores.get(point_id, 0.0) + 1 / (
                ranking_constant + position
            )

    # sorted() is stable, so ties stay in first-seen (insertion) order.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)

    fused_points = []
    for point_id, fused_score in ranked[:limit]:
        original = first_seen[point_id]
        # Copy instead of assigning to `original.score`, so the caller's
        # result lists keep their own scores. pydantic v2 models expose
        # `model_copy`; v1 models only have `copy`. Both accept `update`.
        clone = (
            original.model_copy if hasattr(original, "model_copy") else original.copy
        )
        fused_points.append(clone(update={"score": fused_score}))
    return fused_points
24 changes: 24 additions & 0 deletions qdrant_client/hybrid/test_reranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from qdrant_client.http import models
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion


def test_reciprocal_rank_fusion() -> None:
    # Id "2" is the only point present in both rankings, so its summed
    # reciprocal-rank score beats the two list leaders ("1" and "5"), which
    # tie with each other and may come out in either order.
    def ranking(*entries):  # type: ignore
        return [
            models.ScoredPoint(id=point_id, score=score, version=1)
            for point_id, score in entries
        ]

    first_ranking = ranking(("1", 0.1), ("2", 0.2), ("3", 0.3))
    second_ranking = ranking(("5", 12.0), ("6", 8.0), ("7", 5.0), ("2", 3.0))

    fused = reciprocal_rank_fusion([first_ranking, second_ranking])
    fused_ids = [point.id for point in fused]

    assert fused_ids[0] == "2"
    assert fused_ids[1] in ["1", "5"]
    assert fused_ids[2] in ["1", "5"]
Loading

0 comments on commit 0d7e46c

Please sign in to comment.