Fix duplicate vector ids in FAISS #395

Merged · 4 commits · Sep 18, 2020
Changes from all commits
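This PR fixes duplicate `vector_id` values in `FAISSDocumentStore`. Both `write_documents()` and `update_embeddings()` process documents in chunks of `index_buffer_size`, and previously took `vector_id` from `enumerate()` over each chunk, so ids restarted at 0 for every chunk. Whenever more than `index_buffer_size` documents were written, several documents ended up sharing the same `vector_id`. The fix seeds `vector_id` from `faiss_index.ntotal` (the number of vectors already in the FAISS index) and increments it per document, keeping ids unique across chunks.

A minimal sketch of the difference (illustrative only, not the Haystack code itself; the buffer size of 2 and the three documents mirror the new tests):

```python
# Toy reproduction of the id-assignment bug: 3 documents, buffer size 2.
docs = ["doc_a", "doc_b", "doc_c"]
index_buffer_size = 2
ntotal = 0  # stands in for faiss_index.ntotal (vectors already stored)

# Old behaviour: enumerate() restarts at 0 for every chunk -> duplicate ids.
old_ids = []
for i in range(0, len(docs), index_buffer_size):
    for vector_id, _ in enumerate(docs[i: i + index_buffer_size]):
        old_ids.append(vector_id)
print(old_ids)  # [0, 1, 0] -> id 0 is handed out twice

# Fixed behaviour: continue from the number of vectors already in the index.
new_ids = []
for i in range(0, len(docs), index_buffer_size):
    vector_id = ntotal
    for _ in docs[i: i + index_buffer_size]:
        new_ids.append(vector_id)
        vector_id += 1
    ntotal += len(docs[i: i + index_buffer_size])  # ntotal grows as vectors are added
print(new_ids)  # [0, 1, 2]
```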
20 changes: 13 additions & 7 deletions haystack/document_store/faiss.py
@@ -63,19 +63,22 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O
phi = self._get_phi(document_objects)

for i in range(0, len(document_objects), self.index_buffer_size):
vector_id = faiss_index.ntotal
if add_vectors:
embeddings = [doc.embedding for doc in document_objects[i: i + self.index_buffer_size]]
hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
faiss_index.add(hnsw_vectors)

docs_to_write_in_sql = []
for vector_id, doc in enumerate(document_objects[i : i + self.index_buffer_size]):
for doc in document_objects[i : i + self.index_buffer_size]:
meta = doc.meta
if add_vectors:
meta["vector_id"] = vector_id
vector_id += 1
docs_to_write_in_sql.append(doc)

super(FAISSDocumentStore, self).write_documents(docs_to_write_in_sql, index=index)

self.faiss_index = faiss_index

def _get_hnsw_vectors(self, embeddings: List[np.array], phi: int) -> np.array:
@@ -121,17 +124,19 @@ def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = Non
doc.embedding = embeddings[i]

phi = self._get_phi(documents)
doc_meta_to_update = []

for i in range(0, len(documents), self.index_buffer_size):
embeddings = [doc.embedding for doc in documents[i : i + self.index_buffer_size]]
vector_id = faiss_index.ntotal
embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
hnsw_vectors = self._get_hnsw_vectors(embeddings=embeddings, phi=phi)
faiss_index.add(hnsw_vectors)

doc_meta_to_update = []
for vector_id, doc in enumerate(documents[i : i + self.index_buffer_size]):
meta = doc.meta or {}
meta["vector_id"] = vector_id
doc_meta_to_update.append((doc.id, meta))
for doc in documents[i: i + self.index_buffer_size]:
meta = doc.meta or {}
meta["vector_id"] = vector_id
vector_id += 1
doc_meta_to_update.append((doc.id, meta))

for doc_id, meta in doc_meta_to_update:
super(FAISSDocumentStore, self).update_document_meta(id=doc_id, meta=meta)
@@ -161,6 +166,7 @@ def query_by_embedding(
for doc in documents:
doc.score = scores_for_vector_ids[doc.meta["vector_id"]] # type: ignore
doc.probability = (doc.score + 1) / 2

return documents

def save(self, file_path: Union[str, Path]):
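The new tests below cover this end to end. A condensed usage sketch of the same check (the SQLite URL and buffer size here are illustrative assumptions, not taken from the PR):

```python
import numpy as np
from haystack.document_store.faiss import FAISSDocumentStore

# Force several write chunks for only three documents, so the old code would
# have produced duplicate ids (0, 1, 0 instead of 0, 1, 2).
document_store = FAISSDocumentStore(sql_url="sqlite:///haystack_test.db")
document_store.index_buffer_size = 2

documents = [
    {"name": f"name_{i}", "text": f"text_{i}", "embedding": np.random.rand(768).astype(np.float32)}
    for i in range(1, 4)
]
document_store.write_documents(documents)

# With the fix, every stored document carries a distinct vector_id.
vector_ids = [doc.meta["vector_id"] for doc in document_store.get_all_documents()]
assert len(set(vector_ids)) == len(documents)
```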
61 changes: 59 additions & 2 deletions test/test_faiss.py
@@ -1,22 +1,35 @@
import numpy as np
import pytest
from haystack import Document

from haystack.retriever.dense import DensePassageRetriever

@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
def test_faiss_indexing(document_store):
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
def test_faiss_write_docs(document_store, index_buffer_size):
documents = [
{"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float32)},
]

document_store.index_buffer_size = index_buffer_size

document_store.write_documents(documents)
documents_indexed = document_store.get_all_documents()

# test if correct vector_ids are assigned
for i, doc in enumerate(documents_indexed):
assert doc.meta["vector_id"] == str(i)

# test if correct vectors are associated with docs
for i, doc in enumerate(documents_indexed):
# we currently don't get the embeddings back when we call document_store.get_all_documents()
original_doc = [d for d in documents if d["text"] == doc.text][0]
stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
# compare original input vec with stored one (ignore extra dim added by hnsw)
assert np.allclose(original_doc["embedding"], stored_emb[:-1])

# test that insertion of documents into an existing index fails
with pytest.raises(Exception):
document_store.write_documents(documents)
@@ -25,4 +38,48 @@ def test_faiss_indexing(document_store):
document_store.save("haystack_test_faiss")

# test loading the index
document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")
document_store.load(sql_url="sqlite:///haystack_test.db", faiss_file_path="haystack_test_faiss")

@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
@pytest.mark.parametrize("index_buffer_size", [10_000, 2])
def test_faiss_update_docs(document_store, index_buffer_size):
documents = [
{"name": "name_1", "text": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_2", "text": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_3", "text": "text_3", "embedding": np.random.rand(768).astype(np.float32)},
]

# adjust buffer size
document_store.index_buffer_size = index_buffer_size

# initial write
document_store.write_documents(documents)

# do the update
retriever = DensePassageRetriever(document_store=document_store,
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
use_gpu=False, embed_title=True,
remove_sep_tok_from_untitled_passages=True)

document_store.update_embeddings(retriever=retriever)
documents_indexed = document_store.get_all_documents()

# test if number of documents is correct
assert len(documents_indexed) == len(documents)

# test that no two docs have the same vector_id assigned
vector_ids = set()
for i, doc in enumerate(documents_indexed):
vector_ids.add(doc.meta["vector_id"])
assert len(vector_ids) == len(documents)

# test if correct vectors are associated with docs
for i, doc in enumerate(documents_indexed):
original_doc = [d for d in documents if d["text"] == doc.text][0]
updated_embedding = retriever.embed_passages([Document.from_dict(original_doc)])
stored_emb = document_store.faiss_index.reconstruct(int(doc.meta["vector_id"]))
# compare original input vec with stored one (ignore extra dim added by hnsw)
assert np.allclose(updated_embedding, stored_emb[:-1])