Skip to content

Commit

Permalink
feat: Add IVF and Product Quantization support for OpenSearchDocumentStore (#3850)
Browse files (browse the repository at this point in the history)

* Add IVF and Product Quantization support for OpenSearchDocumentStore

* Remove unused import statement

* Fix mypy

* Adapt doc strings and error messages to account for PQ

* Adapt validation of indices

* Adapt existing tests

* Fix pylint

* Add tests

* Update lg

* Adapt based on PR review comments

* Fix Pylint

* Adapt based on PR review

* Add request_timeout

* Adapt based on PR review

* Adapt based on PR review

* Adapt tests

* Pin tenacity

* Unpin tenacity

* Adapt based on PR comments

* Add match to tests

---------

Co-authored-by: agnieszka-m <[email protected]>
  • Loading branch information
bogdankostic and agnieszka-m authored Feb 17, 2023
1 parent 8370715 commit 7eeb3e0
Show file tree
Hide file tree
Showing 7 changed files with 811 additions and 60 deletions.
11 changes: 0 additions & 11 deletions haystack/document_stores/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(
timeout: int = 30,
return_embedding: bool = False,
duplicate_documents: str = "overwrite",
index_type: str = "flat",
scroll: str = "1d",
skip_missing_embeddings: bool = True,
synonyms: Optional[List] = None,
Expand Down Expand Up @@ -113,8 +112,6 @@ def __init__(
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param index_type: The type of index to be created. Choose from 'flat' and 'hnsw'. Currently the
ElasticsearchDocumentStore does not support HNSW but OpenDistroElasticsearchDocumentStore does.
:param scroll: Determines how long the current index is fixed, e.g. during updating all documents with embeddings.
Defaults to "1d" and should not be larger than this. Can also be in minutes "5m" or hours "15h"
For details, see https://www.elastic.co/guide/en/elasticsearch/reference/current/scroll-api.html
Expand All @@ -132,13 +129,6 @@ def __init__(
:param use_system_proxy: Whether to use system proxy.
"""
# hnsw is only supported in OpensearchDocumentStore
if index_type == "hnsw":
raise DocumentStoreError(
"The HNSW algorithm for approximate nearest neighbours calculation is currently not available in the ElasticSearchDocumentStore. "
"Try the OpenSearchDocumentStore instead."
)

# Base constructor might need the client to be ready, create it first
client = self._init_elastic_client(
host=host,
Expand Down Expand Up @@ -173,7 +163,6 @@ def __init__(
similarity=similarity,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
index_type=index_type,
scroll=scroll,
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,
Expand Down
13 changes: 9 additions & 4 deletions haystack/document_stores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Fil

def train_index(
self,
documents: Optional[Union[List[dict], List[Document]]],
documents: Optional[Union[List[dict], List[Document]]] = None,
embeddings: Optional[np.ndarray] = None,
index: Optional[str] = None,
):
Expand All @@ -474,15 +474,20 @@ def train_index(
:return: None
"""
index = index or self.index
if embeddings and documents:
if isinstance(embeddings, np.ndarray) and documents:
raise ValueError("Either pass `documents` or `embeddings`. You passed both.")

if documents:
document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
doc_embeddings = [doc.embedding for doc in document_objects]
doc_embeddings = [doc.embedding for doc in document_objects if doc.embedding is not None]
embeddings_for_train = np.array(doc_embeddings, dtype="float32")
self.faiss_indexes[index].train(embeddings_for_train)
if embeddings:
elif isinstance(embeddings, np.ndarray):
self.faiss_indexes[index].train(embeddings)
else:
logger.warning(
"When calling `train_index`, you must provide either Documents or embeddings. Because none of these values was provided, no training will be performed. "
)

def delete_all_documents(
self,
Expand Down
Loading

0 comments on commit 7eeb3e0

Please sign in to comment.