Add dummy retriever for benchmarking / reader-only settings (#235)
tanaysoni authored Jul 15, 2020
1 parent eb658d3 commit 5c1a5fe
Showing 8 changed files with 135 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
@@ -18,7 +18,7 @@ services:
- READER_MODEL_PATH=deepset/roberta-base-squad2
# - READER_MODEL_PATH=home/user/models/roberta-base-squad2
# Alternative: If you want to use the TransformersReader (e.g. for loading a local model in transformers format):
# - READER_USE_TRANSFORMERS=True
# - READER_TYPE=TransformersReader
# - READER_MODEL_PATH=/home/user/models/roberta-base-squad2
# - READER_TOKENIZER=/home/user/models/roberta-base-squad2
restart: always
Expand Down
38 changes: 31 additions & 7 deletions haystack/database/elasticsearch.py
@@ -156,8 +156,8 @@ def get_all_documents(self) -> List[Document]:

def query(
self,
query: str,
filters: Optional[dict] = None,
query: Optional[str],
filters: Optional[Dict[str, List[str]]] = None,
top_k: int = 10,
custom_query: Optional[str] = None,
index: Optional[str] = None,
@@ -166,20 +166,41 @@
if index is None:
index = self.index

if custom_query: # substitute placeholder for question and filters for the custom_query template string
template = Template(custom_query)
# Naive retrieval without BM25, only filtering
if query is None:
body = {"query":
{"bool": {"must":
{"match_all": {}}}}} # type: Dict[str, Any]
if filters:
filter_clause = []
for key, values in filters.items():
filter_clause.append(
{
"terms": {key: values}
}
)
body["query"]["bool"]["filter"] = filter_clause

substitutions = {"question": query} # replace all "${question}" placeholder(s) with query
# replace all filter values placeholders with a list of strings(in JSON format) for each filter
# Retrieval via custom query
elif custom_query: # substitute placeholder for question and filters for the custom_query template string
template = Template(custom_query)
# replace all "${question}" placeholder(s) with query
substitutions = {"question": query}
# For each filter we got passed, we'll try to find & replace the corresponding placeholder in the template
# Example: filters={"years":[2018]} => replaces {$years} in custom_query with '[2018]'
if filters:
for key, values in filters.items():
values_str = json.dumps(values)
substitutions[key] = values_str
custom_query_json = template.substitute(**substitutions)
body = json.loads(custom_query_json)
# add top_k
body["size"] = str(top_k)

# Default Retrieval via BM25 using the user query on `self.search_fields`
else:
body = {
"size": top_k,
"size": str(top_k),
"query": {
"bool": {
"should": [{"multi_match": {"query": query, "type": "most_fields", "fields": self.search_fields}}]
@@ -190,6 +211,9 @@
if filters:
filter_clause = []
for key, values in filters.items():
if type(values) != list:
raise ValueError(f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
filter_clause.append(
{
"terms": {key: values}
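The core addition above is the new filter-only branch in ElasticsearchDocumentStore.query(): when query is None, the request body is a match_all query plus one terms clause per filter key, so documents are selected purely by metadata and never scored against a user query. A minimal standalone sketch of that body construction, mirroring the diff (the helper name build_filter_only_body is made up for illustration):

from typing import Any, Dict, List, Optional

def build_filter_only_body(filters: Optional[Dict[str, List[str]]], top_k: int = 10) -> Dict[str, Any]:
    # No scoring signal from the user query: start from match_all.
    body: Dict[str, Any] = {"query": {"bool": {"must": {"match_all": {}}}}}
    if filters:
        # Each metadata key becomes a `terms` clause restricting the allowed values.
        body["query"]["bool"]["filter"] = [
            {"terms": {key: values}} for key, values in filters.items()
        ]
    body["size"] = str(top_k)
    return body

# Example: restrict results to documents whose `name` field is "filename1".
print(build_filter_only_body({"name": ["filename1"]}, top_k=5))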
20 changes: 19 additions & 1 deletion haystack/retriever/sparse.py
@@ -34,7 +34,10 @@ def __init__(self, document_store: ElasticsearchDocumentStore, custom_query: str
"fields": ["text", "title"]}}],
"filter": [ // optional custom filters
{"terms": {"year": "${years}"}},
{"terms": {"quarter": "${quarters}"}}],
{"terms": {"quarter": "${quarters}"}},
{"range": {"date": {"gte": "${date}"}}}
],
}
},
}
@@ -104,6 +107,21 @@ def eval(
return {"recall": recall, "map": mean_avg_precision}


class ElasticsearchFilterOnlyRetriever(ElasticsearchRetriever):
"""
Naive "Retriever" that returns all documents that match the given filters. No impact of query at all.
Helpful for benchmarking, testing and if you want to do QA on small documents without an "active" retriever.
"""

def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
if index is None:
index = self.document_store.index
documents = self.document_store.query(query=None, filters=filters, top_k=top_k,
custom_query=self.custom_query, index=index)
logger.info(f"Got {len(documents)} candidates from retriever")

return documents

# TODO make Paragraph generic for configurable units of text eg, pages, paragraphs, or split by a char_limit
Paragraph = namedtuple("Paragraph", ["paragraph_id", "document_id", "text", "meta"])

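A short usage sketch of the new ElasticsearchFilterOnlyRetriever, assuming a local Elasticsearch instance with documents already indexed; the host and the index name "document" are assumptions for illustration, not part of this diff:

from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever

document_store = ElasticsearchDocumentStore(host="localhost", index="document")
retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)

# The query string has no effect on ranking; only the filters narrow the candidates.
docs = retriever.retrieve(query="ignored", filters={"name": ["filename1"]}, top_k=10)
for doc in docs:
    print(doc.meta.get("name"), doc.text[:80])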
3 changes: 2 additions & 1 deletion rest_api/config.py
@@ -27,7 +27,7 @@

# Reader
READER_MODEL_PATH = os.getenv("READER_MODEL_PATH", None)
READER_USE_TRANSFORMERS = os.getenv("READER_USE_TRANSFORMERS", "False").lower() == "true"
READER_TYPE = os.getenv("READER_TYPE", "FARMReader") # alternative: 'TransformersReader'
READER_TOKENIZER = os.getenv("READER_TOKENIZER", None)
CONTEXT_WINDOW_SIZE = int(os.getenv("CONTEXT_WINDOW_SIZE", 500))
DEFAULT_TOP_K_READER = int(os.getenv("DEFAULT_TOP_K_READER", 5))
@@ -37,6 +37,7 @@
MAX_SEQ_LEN = int(os.getenv("MAX_SEQ_LEN", 256))

# Retriever
RETRIEVER_TYPE = os.getenv("RETRIEVER_TYPE", "ElasticsearchRetriever") # alternatives: 'EmbeddingRetriever', 'ElasticsearchRetriever', 'ElasticsearchFilterOnlyRetriever', None
DEFAULT_TOP_K_RETRIEVER = int(os.getenv("DEFAULT_TOP_K_RETRIEVER", 10))
EXCLUDE_META_DATA_FIELDS = os.getenv("EXCLUDE_META_DATA_FIELDS", None)
if EXCLUDE_META_DATA_FIELDS:
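A hedged illustration of how the two new settings could be combined for a reader-only / benchmarking setup; in practice these would typically be exported in the shell or set in docker-compose before the REST API process imports rest_api/config.py:

import os

os.environ["RETRIEVER_TYPE"] = "ElasticsearchFilterOnlyRetriever"  # skip BM25/embedding retrieval
os.environ["READER_TYPE"] = "FARMReader"                           # or "TransformersReader"
os.environ["READER_MODEL_PATH"] = "deepset/roberta-base-squad2"

# rest_api.config picks these up via os.getenv() at import time.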
28 changes: 21 additions & 7 deletions rest_api/controller/search.py
@@ -10,16 +10,16 @@

from haystack import Finder
from rest_api.config import DB_HOST, DB_PORT, DB_USER, DB_PW, DB_INDEX, ES_CONN_SCHEME, TEXT_FIELD_NAME, SEARCH_FIELD_NAME, \
EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, \
EMBEDDING_DIM, EMBEDDING_FIELD_NAME, EXCLUDE_META_DATA_FIELDS, RETRIEVER_TYPE, EMBEDDING_MODEL_PATH, USE_GPU, READER_MODEL_PATH, \
BATCHSIZE, CONTEXT_WINDOW_SIZE, TOP_K_PER_CANDIDATE, NO_ANS_BOOST, MAX_PROCESSES, MAX_SEQ_LEN, DOC_STRIDE, \
DEFAULT_TOP_K_READER, DEFAULT_TOP_K_RETRIEVER, CONCURRENT_REQUEST_PER_WORKER, FAQ_QUESTION_FIELD_NAME, \
EMBEDDING_MODEL_FORMAT, READER_USE_TRANSFORMERS, READER_TOKENIZER, GPU_NUMBER
EMBEDDING_MODEL_FORMAT, READER_TYPE, READER_TOKENIZER, GPU_NUMBER
from rest_api.controller.utils import RequestLimiter
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.reader.farm import FARMReader
from haystack.reader.transformers import TransformersReader
from haystack.retriever.base import BaseRetriever
from haystack.retriever.sparse import ElasticsearchRetriever
from haystack.retriever.sparse import ElasticsearchRetriever, ElasticsearchFilterOnlyRetriever
from haystack.retriever.dense import EmbeddingRetriever

logger = logging.getLogger(__name__)
@@ -44,26 +44,35 @@
)


if EMBEDDING_MODEL_PATH:
if RETRIEVER_TYPE == "EmbeddingRetriever":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model=EMBEDDING_MODEL_PATH,
model_format=EMBEDDING_MODEL_FORMAT,
gpu=USE_GPU
) # type: BaseRetriever
else:
elif RETRIEVER_TYPE == "ElasticsearchRetriever":
retriever = ElasticsearchRetriever(document_store=document_store)
elif RETRIEVER_TYPE is None or RETRIEVER_TYPE == "ElasticsearchFilterOnlyRetriever":
retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
else:
raise ValueError(f"Could not load Retriever of type '{RETRIEVER_TYPE}'. "
f"Please adjust RETRIEVER_TYPE to one of: "
f"'EmbeddingRetriever', 'ElasticsearchRetriever', 'ElasticsearchFilterOnlyRetriever', None"
f"OR modify rest_api/search.py to support your retriever"
)


if READER_MODEL_PATH: # for extractive doc-qa
if READER_USE_TRANSFORMERS:
if READER_TYPE == "TransformersReader":
use_gpu = -1 if not USE_GPU else GPU_NUMBER
reader = TransformersReader(
model=str(READER_MODEL_PATH),
use_gpu=use_gpu,
context_window_size=CONTEXT_WINDOW_SIZE,
tokenizer=str(READER_TOKENIZER)
) # type: Optional[FARMReader]
else:
elif READER_TYPE == "FARMReader":
reader = FARMReader(
model_name_or_path=str(READER_MODEL_PATH),
batch_size=BATCHSIZE,
@@ -75,6 +84,11 @@
max_seq_len=MAX_SEQ_LEN,
doc_stride=DOC_STRIDE,
) # type: Optional[FARMReader]
else:
raise ValueError(f"Could not load Reader of type '{READER_TYPE}'. "
f"Please adjust READER_TYPE to one of: "
f"'FARMReader', 'TransformersReader', None"
)
else:
reader = None # don't need one for pure FAQ matching

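For context, the retriever and reader selected above are ultimately combined into a Finder by the REST API. A minimal sketch of the equivalent wiring outside the API for a reader-only setting; the exact Finder.get_answers signature is assumed from the Haystack version of this commit:

from haystack import Finder
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.reader.farm import FARMReader
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever

document_store = ElasticsearchDocumentStore(host="localhost", index="document")
retriever = ElasticsearchFilterOnlyRetriever(document_store=document_store)
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=False)

finder = Finder(reader=reader, retriever=retriever)
# Reader-only QA: the filters select the candidate documents, the reader extracts answers.
prediction = finder.get_answers(question="Who lives in Berlin?",
                                filters={"name": ["filename1"]},
                                top_k_retriever=10, top_k_reader=3)
print(prediction["answers"])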
24 changes: 24 additions & 0 deletions test/test_dummy_retriever.py
@@ -0,0 +1,24 @@
from haystack.database.base import Document
import pytest


@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
def test_dummy_retriever(document_store_with_docs):
from haystack.retriever.sparse import ElasticsearchFilterOnlyRetriever
retriever = ElasticsearchFilterOnlyRetriever(document_store_with_docs)

result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
assert type(result[0]) == Document
assert result[0].text == "My name is Carla and I live in Berlin"
assert result[0].meta["name"] == "filename1"

result = retriever.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
assert type(result[0]) == Document
assert result[0].text == "My name is Carla and I live in Berlin"
assert result[0].meta["name"] == "filename1"

result = retriever.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
assert type(result[0]) == Document
assert result[0].text == "My name is Christelle and I live in Paris"
assert result[0].meta["name"] == "filename3"

35 changes: 35 additions & 0 deletions test/test_elastic_retriever.py
@@ -0,0 +1,35 @@
from haystack.retriever.sparse import ElasticsearchRetriever
import pytest


@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
def test_elasticsearch_retrieval(document_store_with_docs):
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?")
assert res[0].text == "My name is Carla and I live in Berlin"
assert len(res) == 3
assert res[0].meta["name"] == "filename1"

@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
def test_elasticsearch_retrieval_filters(document_store_with_docs):
retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
assert res[0].text == "My name is Carla and I live in Berlin"
assert len(res) == 1
assert res[0].meta["name"] == "filename1"

res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
assert len(res) == 0

res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
assert len(res) == 0

retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
assert res[0].text == "My name is Carla and I live in Berlin"
assert len(res) == 1
assert res[0].meta["name"] == "filename1"

retriever = ElasticsearchRetriever(document_store=document_store_with_docs)
res = retriever.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
assert len(res) == 0
4 changes: 2 additions & 2 deletions test/test_reader.py
@@ -66,7 +66,7 @@ def test_context_window_size(test_docs_xs):
doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
docs.append(doc)
for window_size in [10, 15, 20]:
farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", num_processes=0,
use_gpu=False, top_k_per_sample=5, no_ans_boost=None, context_window_size=window_size)
prediction = farm_reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
for answer in prediction["answers"]:
@@ -90,7 +90,7 @@ def test_top_k(test_docs_xs):
for d in test_docs_xs:
doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
docs.append(doc)
farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad", num_processes=0,
use_gpu=False, top_k_per_sample=4, no_ans_boost=None, top_k_per_candidate=4)
for top_k in [2, 5, 10]:
prediction = farm_reader.predict(question="Who lives in Berlin?", documents=docs, top_k=top_k)
