Skip to content

Commit

Permalink
[Paddle-Pipelines] update faiss (#7793)
Browse the repository at this point in the history
* update faiss

* update faiss

* update faiss
Loading branch information
qingzhong1 authored Jan 8, 2024
1 parent 079f067 commit ff1e910
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
6 changes: 1 addition & 5 deletions pipelines/pipelines/document_stores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def update_embeddings(

vector_id_map = {}
for doc in document_batch:
vector_id_map[str(doc.id)] = str(vector_id)
vector_id_map[str(doc.id)] = str(vector_id) + "_" + index
vector_id += 1
self.update_vector_ids(vector_id_map, index=index)
progress_bar.set_description_str("Documents Processed")
Expand Down Expand Up @@ -443,7 +443,6 @@ def get_all_documents_generator(
)
if return_embedding is None:
return_embedding = self.return_embedding

for doc in documents:
if return_embedding:
if doc.meta and doc.meta.get("vector_id") is not None:
Expand Down Expand Up @@ -588,7 +587,6 @@ def query_by_embedding(

if filters:
logger.warning("Query filters are not implemented for the FAISSDocumentStore.")

index = index or self.index
if not self.faiss_indexes.get(index):
raise Exception(f"Index named '{index}' does not exists. Use 'update_embeddings()' to create an index.")
Expand All @@ -599,11 +597,9 @@ def query_by_embedding(
query_emb = query_emb.reshape(1, -1).astype(np.float32)
if self.similarity == "cosine":
self.normalize_embedding(query_emb)

score_matrix, vector_id_matrix = self.faiss_indexes[index].search(query_emb, top_k)
vector_ids_for_query = [str(vector_id) + "_" + index for vector_id in vector_id_matrix[0] if vector_id != -1]
documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)

# assign query score to each document
scores_for_vector_ids: Dict[str, float] = {
str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])
Expand Down
3 changes: 0 additions & 3 deletions pipelines/pipelines/document_stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,13 @@ def get_documents_by_vector_ids(
):
"""Fetch documents by specifying a list of text vector id strings"""
index = index or self.index

documents = []
for i in range(0, len(vector_ids), batch_size):
query = self.session.query(DocumentORM).filter(
DocumentORM.vector_id.in_(vector_ids[i : i + batch_size]), DocumentORM.index == index
)
for row in query.all():
documents.append(self._convert_sql_row_to_document(row))

sorted_documents = sorted(documents, key=lambda doc: vector_ids.index(doc.meta["vector_id"]))
return sorted_documents

Expand Down Expand Up @@ -405,7 +403,6 @@ def write_documents(
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
else:
document_objects = documents

document_objects = self._handle_duplicate_documents(
documents=document_objects, index=index, duplicate_documents=duplicate_documents
)
Expand Down

0 comments on commit ff1e910

Please sign in to comment.