Skip to content

Commit

Permalink
fix: Faiss score for different metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
aamir-s18 committed Aug 13, 2024
1 parent 5e378b7 commit 029b653
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 21 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,6 @@ dmypy.json

uv.lock

res/
res/

playground.ipynb
47 changes: 38 additions & 9 deletions baguetter/indices/dense/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def __init__(
)

if embed_fn is not None and embedding_dim is None:
embedding_dim = embed_fn([". "], is_query=False, show_progress=False).shape[1]
embedding_dim = embed_fn([". "], is_query=False, show_progress=False).shape[
1
]

if embedding_dim is None:
msg = "embedding_dim must be provided if embed_fn is not None."
Expand All @@ -79,7 +81,9 @@ def __init__(
normalize_score=normalize_score,
)
self.key_mapping: dict[int, Key] = {}
self.faiss_index = faiss.index_factory(self.config.embedding_dim, self.config.faiss_string)
self.faiss_index = faiss.index_factory(
self.config.embedding_dim, self.config.faiss_string
)

if self.require_training() and train_samples:
self.train(train_samples)
Expand Down Expand Up @@ -138,7 +142,9 @@ def _load(
Raises:
FileNotFoundError: If the index files are not found in the repository.
"""
state_file_path, index_file_path = BaseDenseIndex.build_index_file_paths(name_or_path)
state_file_path, index_file_path = BaseDenseIndex.build_index_file_paths(
name_or_path
)

if not repository.exists(state_file_path):
msg = f"Index.state {state_file_path} not found in repository."
Expand All @@ -163,9 +169,13 @@ def _load(

def require_training(self) -> bool:
"""Check if the index requires training."""
return hasattr(self.faiss_index, "is_trained") and not self.faiss_index.is_trained
return (
hasattr(self.faiss_index, "is_trained") and not self.faiss_index.is_trained
)

def train(self, values: list[TextOrVector], *, show_progress: bool = False, **kwargs):
def train(
self, values: list[TextOrVector], *, show_progress: bool = False, **kwargs
):
"""Train the index.
Args:
Expand Down Expand Up @@ -246,10 +256,19 @@ def search_many(

scores, indices = self.faiss_index.search(query_vectors, top_k)

# Metric types https://github.com/facebookresearch/faiss/blob/main/faiss/MetricType.h
if self.faiss_index.metric_type != 0: # IF not METRIC_INNER_PRODUCT
scores = 1 / (1 + scores)
return [
SearchResults(
keys=[self.key_mapping[idx] for idx in query_indices if idx != -1],
scores=np.array([score for idx, score in zip(query_indices, query_scores) if idx != -1]),
scores=np.array(
[
score
for idx, score in zip(query_indices, query_scores)
if idx != -1
]
),
normalized=self.config.normalize_score,
)
for query_scores, query_indices in zip(scores, indices)
Expand All @@ -267,7 +286,13 @@ def add(self, key: Key, value: TextOrVector) -> FaissDenseIndex:
"""
return self.add_many([key], [value])

def add_many(self, keys: list[Key], values: list[TextOrVector], *, show_progress: bool = False) -> FaissDenseIndex:
def add_many(
self,
keys: list[Key],
values: list[TextOrVector],
*,
show_progress: bool = False,
) -> FaissDenseIndex:
"""Add multiple items to the index.
Args:
Expand Down Expand Up @@ -320,13 +345,17 @@ def remove_many(self, keys: list[Key]) -> FaissDenseIndex:
FaissDenseIndex: The index instance for method chaining.
"""
inv_key_mapping = {v: k for k, v in self.key_mapping.items()}
indices_to_remove = [inv_key_mapping[key] for key in keys if key in inv_key_mapping]
indices_to_remove = [
inv_key_mapping[key] for key in keys if key in inv_key_mapping
]

if not indices_to_remove:
return self

self.faiss_index.remove_ids(np.array(indices_to_remove, dtype=np.int64))

remaining_keys = [v for k, v in self.key_mapping.items() if k not in indices_to_remove]
remaining_keys = [
v for k, v in self.key_mapping.items() if k not in indices_to_remove
]
self.key_mapping = dict(enumerate(remaining_keys))
return self
55 changes: 44 additions & 11 deletions tests/indices/dense/faiss_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from operator import index
import tempfile

import numpy as np
Expand Down Expand Up @@ -57,7 +58,9 @@ def test_faiss_duplicate_keys(sample_data):
# Try to add duplicate keys with new values
new_values = all_values[5:10]
index.add_many(keys, new_values)
assert index.size == 5 # Size should remain the same as duplicate keys are overwritten
assert (
index.size == 5
) # Size should remain the same as duplicate keys are overwritten

# Verify that new values are searchable
for key, new_value in zip(keys, new_values):
Expand Down Expand Up @@ -162,42 +165,42 @@ def mock_embed_fn(texts, is_query, show_progress=False):

def test_faiss_initialization_options():
# Test different initialization options for FaissDenseIndex

# Test Flat index (default)
flat_index = FaissDenseIndex(embedding_dim=128)
assert flat_index.config.faiss_string == "Flat"

# Test IVF index
ivf_index = FaissDenseIndex(embedding_dim=128, faiss_string="IVF50,Flat")
assert ivf_index.config.faiss_string == "IVF50,Flat"

# Test IVFPQ index
ivfpq_index = FaissDenseIndex(embedding_dim=128, faiss_string="IVF50,PQ8x8")
assert ivfpq_index.config.faiss_string == "IVF50,PQ8x8"

# Test PQ index
pq_index = FaissDenseIndex(embedding_dim=128, faiss_string="PQ16x8")
assert pq_index.config.faiss_string == "PQ16x8"

# Test HNSW index
hnsw_index = FaissDenseIndex(embedding_dim=128, faiss_string="HNSW32")
assert hnsw_index.config.faiss_string == "HNSW32"

# Test LSH index
lsh_index = FaissDenseIndex(embedding_dim=128, faiss_string="LSH")
assert lsh_index.config.faiss_string == "LSH"

# Test if indices are properly initialized and can perform basic operations
keys = [f"key_{i}" for i in range(300)]
values = [np.random.rand(128).astype(np.float32) for _ in range(300)]

for index in [flat_index, ivf_index, ivfpq_index, pq_index, hnsw_index, lsh_index]:
if index.require_training():
index.train(values)

index.add_many(keys, values)
assert index.size == 300

query = np.random.rand(128).astype(np.float32)
results = index.search(query, top_k=5)
assert len(results.keys) == 5
Expand All @@ -209,3 +212,33 @@ def test_faiss_invalid_initialization():
with pytest.raises(RuntimeError):
FaissDenseIndex(embedding_dim=128, faiss_string="InvalidString")


def test_l2_metric(sample_data):
index = FaissDenseIndex(embedding_dim=128)
query_vector = sample_data[1][0]

index.add_many(sample_data[0][1:], sample_data[1][1:])
results = index.search(query_vector)

l2_distances = (
np.linalg.norm(np.array(sample_data[1][1:]) - query_vector, axis=1)
) ** 2
l2_distances = 1 / (1 + l2_distances)
l2_distances = np.sort(-l2_distances)

np.testing.assert_allclose(results.scores, -l2_distances, atol=1e-5)


def test_inner_product_metric(sample_data):
index = FaissDenseIndex(embedding_dim=128)
index.faiss_index.metric_type = 0 # Inner product

query_vector = sample_data[1][0]

index.add_many(sample_data[0][1:], sample_data[1][1:])
results = index.search(query_vector)

inner_products = np.dot(sample_data[1][1:], query_vector)
inner_products = np.sort(-inner_products)

np.testing.assert_allclose(results.scores, -inner_products, atol=1e-5)

0 comments on commit 029b653

Please sign in to comment.