From 5937f9cf16729a271d3cd748a1938945d0de125e Mon Sep 17 00:00:00 2001 From: Tanay Soni Date: Tue, 4 Aug 2020 14:24:12 +0200 Subject: [PATCH] Deprecate Tags for Document Stores (#286) --- haystack/database/base.py | 20 +++---- haystack/database/elasticsearch.py | 34 +++++------- haystack/database/memory.py | 89 +++++++++++------------------- haystack/database/sql.py | 84 ++++++++++------------------ haystack/finder.py | 8 +-- test/conftest.py | 12 ++-- test/test_db.py | 30 +++++++++- test/test_in_memory_store.py | 85 ---------------------------- 8 files changed, 117 insertions(+), 245 deletions(-) delete mode 100644 test/test_in_memory_store.py diff --git a/haystack/database/base.py b/haystack/database/base.py index 6e3ae4e7b1..0abe4763f9 100644 --- a/haystack/database/base.py +++ b/haystack/database/base.py @@ -9,7 +9,6 @@ def __init__(self, text: str, query_score: Optional[float] = None, question: Optional[str] = None, meta: Dict[str, Any] = None, - tags: Optional[Dict[str, Any]] = None, embedding: Optional[List[float]] = None): """ Object used to represent documents / passages in a standardized way within Haystack. @@ -24,7 +23,6 @@ def __init__(self, text: str, :param query_score: Retriever's query score for a retrieved document :param question: Question text for FAQs. :param meta: Meta fields for a document like name, url, or author. - :param tags: Tags that allow filtering of the data :param embedding: Vector encoding of the text """ @@ -38,7 +36,6 @@ def __init__(self, text: str, self.query_score = query_score self.question = question self.meta = meta - self.tags = tags # deprecate? self.embedding = embedding def to_dict(self): @@ -47,7 +44,7 @@ def to_dict(self): @classmethod def from_dict(cls, dict): _doc = dict.copy() - init_args = ["text", "id", "query_score", "question", "meta", "tags", "embedding"] + init_args = ["text", "id", "query_score", "question", "meta", "embedding"] if "meta" not in _doc.keys(): _doc["meta"] = {} # copy additional fields into "meta" @@ -110,14 +107,15 @@ class BaseDocumentStore(ABC): Base class for implementing Document Stores. """ index: Optional[str] + label_index: Optional[str] @abstractmethod def write_documents(self, documents: List[dict], index: Optional[str] = None): """ Indexes documents for later queries. - :param documents: List of dictionaries. - Default format: {"text": ""} + :param documents: a list of Python dictionaries or a list of Haystack Document objects. + For documents as dictionaries, the format is {"text": ""}. Optionally: Include meta data via {"text": "", "meta":{"name": ", "author": "somebody", ...}} It can be used for filtering and is accessible in the responses of the Finder. @@ -129,21 +127,17 @@ def write_documents(self, documents: List[dict], index: Optional[str] = None): pass @abstractmethod - def get_all_documents(self, index: Optional[str] = None) -> List[Document]: + def get_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Document]: pass @abstractmethod - def get_all_labels(self, index: str = "label", filters: Optional[dict] = None) -> List[Label]: + def get_all_labels(self, index: str = "label", filters: Optional[Optional[Dict[str, List[str]]]] = None) -> List[Label]: pass @abstractmethod def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]: pass - @abstractmethod - def get_document_ids_by_tags(self, tag, index) -> List[str]: - pass - @abstractmethod def get_document_count(self, index: Optional[str] = None) -> int: pass @@ -151,7 +145,7 @@ def get_document_count(self, index: Optional[str] = None) -> int: @abstractmethod def query_by_embedding(self, query_emb: List[float], - filters: Optional[dict] = None, + filters: Optional[Optional[Dict[str, List[str]]]] = None, top_k: int = 10, index: Optional[str] = None) -> List[Document]: pass diff --git a/haystack/database/elasticsearch.py b/haystack/database/elasticsearch.py index c8ee099030..c8fe75708f 100644 --- a/haystack/database/elasticsearch.py +++ b/haystack/database/elasticsearch.py @@ -9,6 +9,7 @@ from haystack.database.base import BaseDocumentStore, Document, Label from haystack.indexing.utils import eval_data_from_file +from haystack.retriever.base import BaseRetriever logger = logging.getLogger(__name__) @@ -89,7 +90,7 @@ def __init__( self.index: str = index self._create_label_index(label_index) - self.label_index = label_index + self.label_index: str = label_index self.update_existing_documents = update_existing_documents def _create_document_index(self, index_name): @@ -136,17 +137,6 @@ def get_document_by_id(self, id: str, index=None) -> Optional[Document]: document = self._convert_es_hit_to_document(result[0]) if result else None return document - def get_document_ids_by_tags(self, tags: dict, index: Optional[str]) -> List[str]: - index = index or self.index - term_queries = [{"terms": {key: value}} for key, value in tags.items()] - query = {"query": {"bool": {"must": term_queries}}} - logger.debug(f"Tag filter query: {query}") - result = self.client.search(index=index, body=query, size=10000)["hits"]["hits"] - doc_ids = [] - for hit in result: - doc_ids.append(hit["_id"]) - return doc_ids - def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None): """ Indexes documents for later queries in Elasticsearch. @@ -198,7 +188,8 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O documents_to_index.append(_doc) bulk(self.client, documents_to_index, request_timeout=300, refresh="wait_for") - def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = "label"): + def write_labels(self, labels: Union[List[Label], List[dict]], index: Optional[str] = None): + index = index or self.label_index if index and not self.client.indices.exists(index=index): self._create_label_index(index) @@ -230,7 +221,7 @@ def get_document_count(self, index: Optional[str] = None) -> int: def get_label_count(self, index: Optional[str] = None) -> int: return self.get_document_count(index=index) - def get_all_documents(self, index: Optional[str] = None, filters: Optional[dict] = None) -> List[Document]: + def get_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Document]: if index is None: index = self.index @@ -239,12 +230,13 @@ def get_all_documents(self, index: Optional[str] = None, filters: Optional[dict] return documents - def get_all_labels(self, index: str = "label", filters: Optional[dict] = None) -> List[Label]: + def get_all_labels(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]: + index = index or self.label_index result = self.get_all_documents_in_index(index=index, filters=filters) labels = [Label.from_dict(hit["_source"]) for hit in result] return labels - def get_all_documents_in_index(self, index: str, filters: Optional[dict] = None) -> List[dict]: + def get_all_documents_in_index(self, index: str, filters: Optional[Dict[str, List[str]]] = None) -> List[dict]: body = { "query": { "bool": { @@ -346,7 +338,7 @@ def query( def query_by_embedding(self, query_emb: np.array, - filters: Optional[dict] = None, + filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None) -> List[Document]: if index is None: @@ -392,7 +384,7 @@ def query_by_embedding(self, def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> Document: # We put all additional data of the doc into meta_data and return it in the API - meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.faq_question_field, self.embedding_field, "tags")} + meta_data = {k:v for k,v in hit["_source"].items() if k not in (self.text_field, self.faq_question_field, self.embedding_field)} meta_data["name"] = meta_data.pop(self.name_field, None) document = Document( @@ -401,7 +393,6 @@ def _convert_es_hit_to_document(self, hit: dict, score_adjustment: int = 0) -> D meta=meta_data, query_score=hit["_score"] + score_adjustment if hit["_score"] else None, question=hit["_source"].get(self.faq_question_field), - tags=hit["_source"].get("tags"), embedding=hit["_source"].get(self.embedding_field) ) return document @@ -420,12 +411,13 @@ def describe_documents(self, index=None): } return stats - def update_embeddings(self, retriever, index=None): + def update_embeddings(self, retriever: BaseRetriever, index: Optional[str] = None): """ Updates the embeddings in the the document store using the encoding model specified in the retriever. This can be useful if want to add or change the embeddings for your documents (e.g. after changing the retriever config). :param retriever: Retriever + :param index: Index name to update :return: None """ if index is None: @@ -439,7 +431,7 @@ def update_embeddings(self, retriever, index=None): #TODO Index embeddings every X batches to avoid OOM for huge document collections logger.info(f"Updating embeddings for {len(passages)} docs ...") - embeddings = retriever.embed_passages(passages) + embeddings = retriever.embed_passages(passages) # type: ignore assert len(docs) == len(embeddings) diff --git a/haystack/database/memory.py b/haystack/database/memory.py index 101ed0f8ff..64a351d55f 100644 --- a/haystack/database/memory.py +++ b/haystack/database/memory.py @@ -12,7 +12,6 @@ class InMemoryDocumentStore(BaseDocumentStore): """ def __init__(self, embedding_field: Optional[str] = None): - self.doc_tags: Dict[str, Any] = {} self.indexes: Dict[str, Dict] = defaultdict(dict) self.index: str = "document" self.label_index: str = "label" @@ -22,10 +21,11 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O Indexes documents for later queries. - :param documents: a list of Python dictionaries or a list of Haystack Document objects. + :param documents: a list of Python dictionaries or a list of Haystack Document objects. For documents as dictionaries, the format is {"text": ""}. - Optionally, you can also supply "tags": ["one-tag", "another-one"] - or additional meta data via "meta": {"name": ", "author": "someone", "url":"some-url" ...} + Optionally: Include meta data via {"text": "", + "meta": {"name": ", "author": "somebody", ...}} + It can be used for filtering and is accessible in the responses of the Finder. :param index: write documents to a custom namespace. For instance, documents for evaluation can be indexed in a separate index than the documents for search. :return: None @@ -37,10 +37,6 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O for document in documents_objects: self.indexes[index][document.id] = document - #TODO fix tags after id refactoring - tags = document.tags - self._map_tags_to_ids(document.id, tags) - def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[str] = None): index = index or self.label_index label_objects = [Label.from_dict(l) if isinstance(l, dict) else l for l in labels] @@ -49,21 +45,6 @@ def write_labels(self, labels: Union[List[dict], List[Label]], index: Optional[s label_id = str(uuid4()) self.indexes[index][label_id] = label - def _map_tags_to_ids(self, hash: str, tags: List[str]): - if isinstance(tags, list): - for tag in tags: - if isinstance(tag, dict): - tag_keys = tag.keys() - for tag_key in tag_keys: - tag_values = tag.get(tag_key, []) - if tag_values: - for tag_value in tag_values: - comp_key = str((tag_key, tag_value)) - if comp_key in self.doc_tags: - self.doc_tags[comp_key].append(hash) - else: - self.doc_tags[comp_key] = [hash] - def get_document_by_id(self, id: str, index: Optional[str] = None) -> Document: index = index or self.index return self.indexes[index][id] @@ -79,7 +60,7 @@ def _convert_memory_hit_to_document(self, hit: Dict[str, Any], doc_id: Optional[ def query_by_embedding(self, query_emb: List[float], - filters: Optional[dict] = None, + filters: Optional[Dict[str, List[str]]] = None, top_k: int = 10, index: Optional[str] = None) -> List[Document]: @@ -116,44 +97,36 @@ def update_embeddings(self, retriever): #TODO raise NotImplementedError("update_embeddings() is not yet implemented for this DocumentStore") - def get_document_ids_by_tags(self, tags: Union[List[Dict[str, Union[str, List[str]]]], Dict[str, Union[str, List[str]]]], index: Optional[str] = None) -> List[str]: - """ - The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...} - The format for the dict is {"tag-1": ["value-1","value-2"], "tag-2": ["value-3]" ...} - """ - index = index or self.index - if not isinstance(tags, list): - tags = [tags] - result = self._find_ids_by_tags(tags, index=index) - return result - - def _find_ids_by_tags(self, tags: List[Dict[str, Union[str, List[str]]]], index: str): - result = [] - for tag in tags: - tag_keys = tag.keys() - for tag_key in tag_keys: - tag_values = tag.get(tag_key, None) - if tag_values: - for tag_value in tag_values: - comp_key = str((tag_key, tag_value)) - doc_ids = self.doc_tags.get(comp_key, []) - for doc_id in doc_ids: - result.append(self.indexes[index].get(doc_id)) - return result - - def get_document_count(self, index=None) -> int: + def get_document_count(self, index: Optional[str] = None) -> int: index = index or self.index return len(self.indexes[index].items()) - def get_label_count(self, index=None) -> int: + def get_label_count(self, index: Optional[str] = None) -> int: index = index or self.label_index return len(self.indexes[index].items()) - def get_all_documents(self, index=None) -> List[Document]: + def get_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Document]: index = index or self.index - return list(self.indexes[index].values()) + documents = list(self.indexes[index].values()) + filtered_documents = [] + + if filters: + for doc in documents: + is_hit = True + for key, values in filters.items(): + if doc.meta.get(key): + if doc.meta[key] not in values: + is_hit = False + else: + is_hit = False + if is_hit: + filtered_documents.append(doc) + else: + filtered_documents = documents + + return filtered_documents - def get_all_labels(self, index=None, filters=None) -> List[Label]: + def get_all_labels(self, index: str = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Label]: index = index or self.label_index if filters: @@ -172,7 +145,7 @@ def get_all_labels(self, index=None, filters=None) -> List[Label]: return result - def add_eval_data(self, filename: str, doc_index: str = "document", label_index: str = "label"): + def add_eval_data(self, filename: str, doc_index: Optional[str] = None, label_index: Optional[str] = None): """ Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it. @@ -185,10 +158,12 @@ def add_eval_data(self, filename: str, doc_index: str = "document", label_index: """ docs, labels = eval_data_from_file(filename) + doc_index = doc_index or self.index + label_index = label_index or self.label_index self.write_documents(docs, index=doc_index) self.write_labels(labels, index=label_index) - def delete_all_documents(self, index=None): + def delete_all_documents(self, index: Optional[str] = None): """ Delete all documents in a index. @@ -197,4 +172,4 @@ def delete_all_documents(self, index=None): """ index = index or self.index - self.indexes[index] = {} \ No newline at end of file + self.indexes[index] = {} diff --git a/haystack/database/sql.py b/haystack/database/sql.py index 8ded8ffa30..1aaf774dcf 100644 --- a/haystack/database/sql.py +++ b/haystack/database/sql.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Union, List, Optional from uuid import uuid4 -from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, PickleType, Boolean +from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker @@ -14,7 +14,7 @@ class ORMBase(Base): __abstract__ = True - id = Column(String, primary_key=True) + id = Column(String, default=lambda: str(uuid4()), primary_key=True) created = Column(DateTime, server_default=func.now()) updated = Column(DateTime, server_default=func.now(), server_onupdate=func.now()) @@ -24,25 +24,24 @@ class DocumentORM(ORMBase): text = Column(String, nullable=False) index = Column(String, nullable=False) - meta_data = Column(PickleType) - tags = relationship("TagORM", secondary="document_tag", backref="Document") + meta = relationship("MetaORM", secondary="document_meta", backref="Document") -class TagORM(ORMBase): - __tablename__ = "tag" +class MetaORM(ORMBase): + __tablename__ = "meta" name = Column(String) value = Column(String) - documents = relationship(DocumentORM, secondary="document_tag", backref="Tag") + documents = relationship(DocumentORM, secondary="document_meta", backref="Meta") -class DocumentTagORM(ORMBase): - __tablename__ = "document_tag" +class DocumentMetaORM(ORMBase): + __tablename__ = "document_meta" document_id = Column(String, ForeignKey("document.id"), nullable=False) - tag_id = Column(Integer, ForeignKey("tag.id"), nullable=False) + meta_id = Column(Integer, ForeignKey("meta.id"), nullable=False) class LabelORM(ORMBase): @@ -75,11 +74,17 @@ def get_document_by_id(self, id: str, index=None) -> Optional[Document]: document = document_row or self._convert_sql_row_to_document(document_row) return document - def get_all_documents(self, index=None) -> List[Document]: + def get_all_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> List[Document]: index = index or self.index - document_rows = self.session.query(DocumentORM).filter_by(index=index).all() - documents = [self._convert_sql_row_to_document(row) for row in document_rows] + if filters: + for key, values in filters.items(): + results = self.session.query(DocumentORM).filter(DocumentORM.meta.any(MetaORM.name.in_([key]))).\ + filter(DocumentORM.meta.any(MetaORM.value.in_(values))).all() + else: + results = self.session.query(DocumentORM).filter_by(index=index).all() + + documents = [self._convert_sql_row_to_document(row) for row in results] return documents def get_all_labels(self, index=None, filters: Optional[dict] = None): @@ -89,45 +94,15 @@ def get_all_labels(self, index=None, filters: Optional[dict] = None): return labels - def get_document_ids_by_tags(self, tags: Dict[str, Union[str, List]], index: Optional[str] = None) -> List[str]: - """ - Get list of document ids that have tags from the given list of tags. - - :param tags: limit scope to documents having the given tags and their corresponding values. - The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...} - """ - if not tags: - raise Exception("No tag supplied for filtering the documents") - - if index: - raise Exception("'index' parameter is not supported in SQLDocumentStore.get_document_ids_by_tags().") - - query = """ - SELECT id FROM document WHERE id in ( - SELECT dt.document_id - FROM document_tag dt JOIN - tag t - ON t.id = dt.tag_id - GROUP BY dt.document_id - """ - tag_filters = [] - for tag in tags: - tag_filters.append(f"SUM(CASE WHEN t.value='{tag}' THEN 1 ELSE 0 END) > 0") - - final_query = f"{query} HAVING {' AND '.join(tag_filters)});" - query_results = self.session.execute(final_query) - - doc_ids = [row[0] for row in query_results] - return doc_ids - def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None): """ Indexes documents for later queries. - :param documents: a list of Python dictionaries or a list of Haystack Document objects. + :param documents: a list of Python dictionaries or a list of Haystack Document objects. For documents as dictionaries, the format is {"text": ""}. - Optionally, you can also supply "tags": ["one-tag", "another-one"] - or additional meta data via "meta": {"name": ", "author": "someone", "url":"some-url" ...} + Optionally: Include meta data via {"text": "", + "meta":{"name": ", "author": "somebody", ...}} + It can be used for filtering and is accessible in the responses of the Finder. :param index: add an optional index attribute to documents. It can be later used for filtering. For instance, documents for evaluation can be indexed in a separate index than the documents for search. @@ -135,11 +110,12 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O """ # Make sure we comply to Document class format - documents = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents] + document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents] index = index or self.index - for doc in documents: - row = DocumentORM(id=doc.id, text=doc.text, meta_data=doc.meta, index=index) # type: ignore - self.session.add(row) + for doc in document_objects: + meta_orms = [MetaORM(name=key, value=value) for key, value in doc.meta.items()] + doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index) + self.session.add(doc_orm) self.session.commit() def write_labels(self, labels, index=None): @@ -148,7 +124,6 @@ def write_labels(self, labels, index=None): index = index or self.index for label in labels: label_orm = LabelORM( - id=str(uuid4()), document_id=label.document_id, no_answer=label.no_answer, origin=label.origin, @@ -163,7 +138,7 @@ def write_labels(self, labels, index=None): self.session.add(label_orm) self.session.commit() - def add_eval_data(self, filename: str, doc_index: str = "document", label_index: str = "label"): + def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "label"): """ Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it. @@ -191,8 +166,7 @@ def _convert_sql_row_to_document(self, row) -> Document: document = Document( id=row.id, text=row.text, - meta=row.meta_data, - tags=row.tags + meta={meta.name: meta.value for meta in row.meta} ) return document diff --git a/haystack/finder.py b/haystack/finder.py index 4ab315b8e4..67df8b6fd3 100644 --- a/haystack/finder.py +++ b/haystack/finder.py @@ -32,8 +32,8 @@ def get_answers(self, question: str, top_k_reader: int = 1, top_k_retriever: int :param question: the question string :param top_k_reader: number of answers returned by the reader :param top_k_retriever: number of text units to be retrieved - :param filters: limit scope to documents having the given tags and their corresponding values. - The format for the dict is {"tag-1": ["value-1","value-2"], "tag-2": ["value-3]" ...} + :param filters: limit scope to documents having the given meta data values. + The format for the dict is {"key-1": ["value-1", "value-2"], "key-2": ["value-3]" ...} :return: """ @@ -71,8 +71,8 @@ def get_answers_via_similar_questions(self, question: str, top_k_retriever: int :param question: the question string :param top_k_retriever: number of text units to be retrieved - :param filters: limit scope to documents having the given tags and their corresponding values. - The format for the dict is {"tag-1": "value-1", "tag-2": "value-2" ...} + :param filters: limit scope to documents having the given meta data values. + The format for the dict is {"key-1": ["value-1", "value-2"], "key-2": ["value-3]" ...} :return: """ diff --git a/test/conftest.py b/test/conftest.py index d068646cfd..ad8cb749fe 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,20 +1,19 @@ - +import os import tarfile import time import urllib.request from subprocess import Popen, PIPE, STDOUT, run -import os import pytest from elasticsearch import Elasticsearch +from haystack.database.base import Document +from haystack.database.elasticsearch import ElasticsearchDocumentStore +from haystack.database.memory import InMemoryDocumentStore +from haystack.database.sql import SQLDocumentStore from haystack.reader.farm import FARMReader from haystack.reader.transformers import TransformersReader -from haystack.database.base import Document -from haystack.database.sql import SQLDocumentStore -from haystack.database.memory import InMemoryDocumentStore -from haystack.database.elasticsearch import ElasticsearchDocumentStore @pytest.fixture(scope='session') def elasticsearch_dir(tmpdir_factory): @@ -122,7 +121,6 @@ def document_store_with_docs(request, test_docs_xs, elasticsearch_fixture): document_store = ElasticsearchDocumentStore(index="haystack_test") assert document_store.get_document_count() == 0 document_store.write_documents(test_docs_xs) - time.sleep(2) return document_store diff --git a/test/test_db.py b/test/test_db.py index ed772c53c8..6cf6ddf275 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -1,21 +1,45 @@ import pytest -import time from haystack.database.base import Document -def test_get_all_documents(document_store_with_docs): +def test_get_all_documents_without_filters(document_store_with_docs): documents = document_store_with_docs.get_all_documents() assert all(isinstance(d, Document) for d in documents) assert len(documents) == 3 assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3"} assert {d.meta["meta_field"] for d in documents} == {"test1", "test2", "test3"} + + +def test_get_all_documents_with_correct_filters(document_store_with_docs): + documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]}) + assert len(documents) == 1 + assert documents[0].meta["name"] == "filename2" + + documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test3"]}) + assert len(documents) == 2 + assert {d.meta["name"] for d in documents} == {"filename1", "filename3"} + assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"} + + +def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs): + documents = document_store_with_docs.get_all_documents(filters={"incorrect_meta_field": ["test2"]}) + assert len(documents) == 0 + + +def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs): + documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["incorrect_value"]}) + assert len(documents) == 0 + + +def test_get_documents_by_id(document_store_with_docs): + documents = document_store_with_docs.get_all_documents() doc = document_store_with_docs.get_document_by_id(documents[0].id) assert doc.id == documents[0].id assert doc.text == documents[0].text -@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True) +@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True) def test_elasticsearch_update_meta(document_store_with_docs): document = document_store_with_docs.query(query=None, filters={"name": ["filename1"]})[0] document_store_with_docs.update_document_meta(document.id, meta={"meta_field": "updated_meta"}) diff --git a/test/test_in_memory_store.py b/test/test_in_memory_store.py deleted file mode 100644 index 97c4e0f378..0000000000 --- a/test/test_in_memory_store.py +++ /dev/null @@ -1,85 +0,0 @@ -from haystack import Finder -from haystack.reader.transformers import TransformersReader -from haystack.retriever.sparse import TfidfRetriever - - -def test_finder_get_answers_with_in_memory_store(): - test_docs = [ - {"text": "testing the finder with pyhton unit test 1", 'meta': {"name": "testing the finder 1", 'url': 'url'}}, - {"text": "testing the finder with pyhton unit test 2", 'meta': {"name": "testing the finder 2", 'url': 'url'}}, - {"text": "testing the finder with pyhton unit test 3", 'meta': {"name": "testing the finder 3", 'url': 'url'}} - ] - - from haystack.database.memory import InMemoryDocumentStore - document_store = InMemoryDocumentStore() - document_store.write_documents(test_docs) - - retriever = TfidfRetriever(document_store=document_store) - reader = TransformersReader(model="distilbert-base-uncased-distilled-squad", - tokenizer="distilbert-base-uncased", use_gpu=-1) - finder = Finder(reader, retriever) - prediction = finder.get_answers(question="testing finder", top_k_retriever=10, - top_k_reader=5) - assert prediction is not None - - -def test_memory_store_get_by_tags(): - test_docs = [ - {"text": "testing the finder with pyhton unit test 1", 'meta': {"name": "testing the finder 1", 'url': 'url'}}, - {"text": "testing the finder with pyhton unit test 2", 'meta': {"name": "testing the finder 2", 'url': None}}, - {"text": "testing the finder with pyhton unit test 3", 'meta': {"name": "testing the finder 3", 'url': 'url'}} - ] - - from haystack.database.memory import InMemoryDocumentStore - document_store = InMemoryDocumentStore() - document_store.write_documents(test_docs) - - docs = document_store.get_document_ids_by_tags({'has_url': 'false'}) - - assert docs == [] - - -def test_memory_store_get_by_tag_lists_union(): - test_docs = [ - {"text": "testing the finder with pyhton unit test 1", 'meta': {"name": "testing the finder 1", 'url': 'url'}, 'tags': [{'tag2': ["1"]}]}, - {"text": "testing the finder with pyhton unit test 2", 'meta': {"name": "testing the finder 2", 'url': None}, 'tags': [{'tag1': ['1']}]}, - {"text": "testing the finder with pyhton unit test 3", 'meta': {"name": "testing the finder 3", 'url': 'url'}, 'tags': [{'tag2': ["1", "2"]}]} - ] - - from haystack.database.memory import InMemoryDocumentStore - document_store = InMemoryDocumentStore() - document_store.write_documents(test_docs) - - docs = document_store.get_document_ids_by_tags({'tag2': ["1"]}) - assert docs[0].text == 'testing the finder with pyhton unit test 1' - assert docs[1].text == 'testing the finder with pyhton unit test 3' - assert docs[1].text == 'testing the finder with pyhton unit test 3' - assert docs[1].tags[0] == {"tag2": ["1", "2"]} - -def test_memory_store_get_by_tag_lists_non_existent_tag(): - test_docs = [ - {"text": "testing the finder with pyhton unit test 1", 'meta': {'url': 'url', "name": "testing the finder 1"}, 'tags': [{'tag1': ["1"]}]}, - ] - from haystack.database.memory import InMemoryDocumentStore - document_store = InMemoryDocumentStore() - document_store.write_documents(test_docs) - docs = document_store.get_document_ids_by_tags({'tag1': ["3"]}) - assert docs == [] - - -def test_memory_store_get_by_tag_lists_disjoint(): - test_docs = [ - {"text": "testing the finder with pyhton unit test 1", 'meta': {"name": "testing the finder 1", 'url': 'url'}, 'tags': [{'tag1': ["1"]}]}, - {"text": "testing the finder with pyhton unit test 2", 'meta': {"name": "testing the finder 2", 'url': None}, 'tags': [{'tag2': ['1']}]}, - {"text": "testing the finder with pyhton unit test 3", 'meta': {"name": "testing the finder 3", 'url': 'url'}, 'tags': [{'tag3': ["1", "2"]}]}, - {"text": "testing the finder with pyhton unit test 3", 'meta': {"name": "testing the finder 4", 'url': 'url'}, 'tags': [{'tag3': ["1", "3"]}]} - ] - - from haystack.database.memory import InMemoryDocumentStore - document_store = InMemoryDocumentStore() - document_store.write_documents(test_docs) - - docs = document_store.get_document_ids_by_tags({'tag3': ["3"]}) - assert len(docs) == 1 - assert docs[0].text == 'testing the finder with pyhton unit test 3' - assert docs[0].tags[0] == {"tag3": ["1", "3"]} \ No newline at end of file