Prompt node/run batch (#4072)
* Starting to implement first pass at run_batch

* Started to add _flatten_input function

* First pass at run_batch method.

* Fixed bug

* Adding tests for run_batch

* Update doc strings

* Pylint and mypy

* Pylint

* Fixing mypy

* Restructuring of run_batch tests

* Add minor lg updates

* Adding more tests

* Update dev comments and call static method differently

* Fixed the setting of output variable

* Set output_variable in __init__ of PromptNode

* Make a one-liner

---------

Co-authored-by: agnieszka-m <[email protected]>
sjrl and agnieszka-m authored Feb 20, 2023
1 parent 83d615a commit d129598
Showing 2 changed files with 196 additions and 22 deletions.
142 changes: 122 additions & 20 deletions haystack/nodes/prompt/prompt_node.py
@@ -741,7 +741,7 @@ def __init__(
super().__init__()
self.prompt_templates: Dict[str, PromptTemplate] = {pt.name: pt for pt in get_predefined_prompt_templates()} # type: ignore
self.default_prompt_template: Union[str, PromptTemplate, None] = default_prompt_template
self.output_variable: Optional[str] = output_variable
self.output_variable: str = output_variable or "results"
self.model_name_or_path: Union[str, PromptModel] = model_name_or_path
self.prompt_model: PromptModel
self.stop_words: Optional[List[str]] = stop_words
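A minimal sketch of the new default set in __init__: when output_variable is omitted, the node now falls back to "results" (the model name below is only illustrative):

from haystack.nodes import PromptNode

# Illustrative model; any model_name_or_path accepted by PromptNode behaves the same way.
node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis")
assert node.output_variable == "results"  # fallback applied in __init__

node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis", output_variable="out")
assert node.output_variable == "out"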
@@ -924,8 +924,10 @@ def run(
invocation_context: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict, str]:
"""
Runs the PromptNode on these input parameters. Returns the output of the prompt model.
Parameters `file_paths`, `labels`, and `meta` are usually ignored.
Runs the PromptNode on these input parameters. Returns the output of the prompt model.
The parameters `query`, `file_paths`, `labels`, `documents` and `meta` are added to the invocation context
before invoking the prompt model. PromptNode uses these variables only if they are present as
parameters in the PromptTemplate.
:param query: The PromptNode usually ignores the query, unless it's used as a parameter in the
prompt template.
@@ -934,7 +936,8 @@ def run(
:param labels: The PromptNode usually ignores the labels, unless they're used as a parameter in the
prompt template.
:param documents: The documents to be used for the prompt.
:param meta: The meta to be used for the prompt. Usually not used.
:param meta: PromptNode usually ignores meta information, unless it's used as a parameter in the
PromptTemplate.
:param invocation_context: The invocation context to be used for the prompt.
"""
# prompt_collector is an empty list, it's passed to the PromptNode that will fill it with the rendered prompts,
@@ -967,29 +970,128 @@ def run(

results = self(prompt_collector=prompt_collector, **invocation_context)

final_result: Dict[str, Any] = {}
output_variable = self.output_variable or "results"
if output_variable:
invocation_context[output_variable] = results
final_result[output_variable] = results

final_result["invocation_context"] = invocation_context
final_result["_debug"] = {"prompts_used": prompt_collector}
invocation_context[self.output_variable] = results
final_result: Dict[str, Any] = {
self.output_variable: results,
"invocation_context": invocation_context,
"_debug": {"prompts_used": prompt_collector},
}
return final_result, "output_1"
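With the restructured return value, the first element of the tuple bundles the generated outputs, the invocation context, and the debug prompts under fixed keys. A hedged illustration of reading it, reusing the node from the sketch above (the document text is illustrative):

from haystack.schema import Document

result, _ = node.run(documents=[Document("Berlin is an amazing city.")])
# result["results"]                -> generated strings, keyed by output_variable ("results" by default)
# result["invocation_context"]     -> the inputs passed in, plus the generated results
# result["_debug"]["prompts_used"] -> the prompts that were rendered and sent to the model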

def run_batch(
def run_batch( # type: ignore
self,
queries: Optional[Union[str, List[str]]] = None,
file_paths: Optional[List[str]] = None,
labels: Optional[Union[MultiLabel, List[MultiLabel]]] = None,
queries: Optional[List[str]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
params: Optional[dict] = None,
debug: Optional[bool] = None,
invocation_contexts: Optional[List[Dict[str, Any]]] = None,
):
raise NotImplementedError("run_batch is not implemented for PromptNode.")
"""
Runs PromptNode in batch mode.
- If you provide a list containing a single query (and/or invocation context)...
- ... and a single list of Documents, the query is applied to each Document individually.
- ... and a list of lists of Documents, the query is applied to each list of Documents and the results
are aggregated per Document list.
- If you provide a list of multiple queries (and/or multiple invocation contexts)...
- ... and a single list of Documents, each query (and/or invocation context) is applied to each Document individually.
- ... and a list of lists of Documents, each query (and/or invocation context) is applied to its corresponding list of Documents
and the results are aggregated per query-Document pair.
- If you provide no Documents, then each query (and/or invocation context) is applied directly to the PromptTemplate.
:param queries: List of queries.
:param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
:param invocation_contexts: List of invocation contexts.
"""
inputs = PromptNode._flatten_inputs(queries, documents, invocation_contexts)
all_results: Dict[str, List] = {self.output_variable: [], "invocation_contexts": [], "_debug": []}
for query, docs, invocation_context in zip(
inputs["queries"], inputs["documents"], inputs["invocation_contexts"]
):
results = self.run(query=query, documents=docs, invocation_context=invocation_context)[0]
all_results[self.output_variable].append(results[self.output_variable])
all_results["invocation_contexts"].append(all_results["invocation_contexts"])
all_results["_debug"].append(all_results["_debug"])
return all_results, "output_1"
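A minimal sketch of calling the new method directly on a node, mirroring the batch tests further down (model and documents are illustrative):

from haystack.nodes import PromptNode
from haystack.schema import Document

node = PromptNode("google/flan-t5-base", default_prompt_template="sentiment-analysis")
all_results, _ = node.run_batch(
    queries=None,
    documents=[Document("Berlin is an amazing city."), Document("I am not feeling well.")],
)
# With a single flat list of Documents, each Document is prompted on its own,
# so all_results["results"] contains one list of generated strings per input Document.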

def _prepare_model_kwargs(self):
# these are the parameters from PromptNode level
# that are passed to the prompt model invocation layer
return {"stop_words": self.stop_words}

@staticmethod
def _flatten_inputs(
queries: Optional[List[str]] = None,
documents: Optional[Union[List[Document], List[List[Document]]]] = None,
invocation_contexts: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, List]:
"""Flatten and copy the queries, documents, and invocation contexts into lists of equal length.
- If you provide a list containing a single query (and/or invocation context)...
- ... and a single list of Documents, the query is applied to each Document individually.
- ... and a list of lists of Documents, the query is applied to each list of Documents and the results
are aggregated per Document list.
- If you provide a list of multiple queries (and/or multiple invocation contexts)...
- ... and a single list of Documents, each query (and/or invocation context) is applied to each Document individually.
- ... and a list of lists of Documents, each query (and/or invocation context) is applied to its corresponding list of Documents
and the results are aggregated per query-Document pair.
- If you provide no Documents, then each query (and/or invocation context) is applied to the PromptTemplate.
:param queries: List of queries.
:param documents: Single list of Documents or list of lists of Documents in which to search for the answers.
:param invocation_contexts: List of invocation contexts.
"""
# Check that queries and invocation_contexts have the same length, if both are provided
input_queries: List[Any]
input_invocation_contexts: List[Any]
if queries is not None and invocation_contexts is not None:
if len(queries) != len(invocation_contexts):
raise ValueError("The input variables queries and invocation_contexts should have the same length.")
input_queries = queries
input_invocation_contexts = invocation_contexts
elif queries is not None and invocation_contexts is None:
input_queries = queries
input_invocation_contexts = [None] * len(queries)
elif queries is None and invocation_contexts is not None:
input_queries = [None] * len(invocation_contexts)
input_invocation_contexts = invocation_contexts
else:
input_queries = [None]
input_invocation_contexts = [None]

multi_docs_list = isinstance(documents, list) and len(documents) > 0 and isinstance(documents[0], list)
single_docs_list = isinstance(documents, list) and len(documents) > 0 and isinstance(documents[0], Document)

# Docs case 1: single list of Documents
# -> apply each query (and invocation_contexts) to all Documents
inputs: Dict[str, List] = {"queries": [], "invocation_contexts": [], "documents": []}
if documents is not None:
if single_docs_list:
for query, invocation_context in zip(input_queries, input_invocation_contexts):
for doc in documents:
inputs["queries"].append(query)
inputs["invocation_contexts"].append(invocation_context)
inputs["documents"].append([doc])
# Docs case 2: list of lists of Documents
# -> apply each query (and invocation_context) to corresponding list of Documents,
# if queries contains only one query, apply it to each list of Documents
elif multi_docs_list:
total_queries = input_queries.copy()
total_invocation_contexts = input_invocation_contexts.copy()
if len(total_queries) == 1 and len(total_invocation_contexts) == 1:
total_queries = input_queries * len(documents)
total_invocation_contexts = input_invocation_contexts * len(documents)
if len(total_queries) != len(documents) or len(total_invocation_contexts) != len(documents):
raise ValueError("Number of queries must be equal to number of provided Document lists.")
for query, invocation_context, cur_docs in zip(total_queries, total_invocation_contexts, documents):
inputs["queries"].append(query)
inputs["invocation_contexts"].append(invocation_context)
inputs["documents"].append(cur_docs)
elif queries is not None or invocation_contexts is not None:
for query, invocation_context in zip(input_queries, input_invocation_contexts):
inputs["queries"].append(query)
inputs["invocation_contexts"].append(invocation_context)
inputs["documents"].append([None])
return inputs
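To make the flattening rules concrete, a hedged illustration of the helper for the two supported Document layouts (Document contents are illustrative):

from haystack.nodes import PromptNode
from haystack.schema import Document

d1, d2, d3 = Document("doc one"), Document("doc two"), Document("doc three")

# Case 1: a single flat list of Documents. The lone query is repeated once per Document,
# and every Document ends up in its own single-element list.
flat = PromptNode._flatten_inputs(queries=["q"], documents=[d1, d2, d3])
# flat["queries"]   == ["q", "q", "q"]
# flat["documents"] == [[d1], [d2], [d3]]

# Case 2: a list of lists of Documents. The lone query is applied to each inner list as a whole.
nested = PromptNode._flatten_inputs(queries=["q"], documents=[[d1, d2], [d3]])
# nested["queries"]   == ["q", "q"]
# nested["documents"] == [[d1, d2], [d3]]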
76 changes: 74 additions & 2 deletions test/nodes/test_prompt_node.py
@@ -319,12 +319,12 @@ def test_simple_pipeline(prompt_model):
if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
pytest.skip("No API key found for OpenAI, skipping test")

node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")
node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(query="not relevant", documents=[Document("Berlin is an amazing city.")])
assert result["results"][0].casefold() == "positive"
assert result["out"][0].casefold() == "positive"


@pytest.mark.integration
@@ -748,6 +748,78 @@ def test_complex_pipeline_with_multiple_same_prompt_node_components_yaml(tmp_pat
assert pipeline is not None


class TestRunBatch:
@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
def test_simple_pipeline_batch_no_query_single_doc_list(self, prompt_model):
if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
pytest.skip("No API key found for OpenAI, skipping test")

node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis")

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run_batch(
queries=None, documents=[Document("Berlin is an amazing city."), Document("I am not feeling well.")]
)
assert isinstance(result["results"], list)
assert isinstance(result["results"][0], list)
assert isinstance(result["results"][0][0], str)
assert "positive" in result["results"][0][0].casefold()
assert "negative" in result["results"][1][0].casefold()

@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
def test_simple_pipeline_batch_no_query_multiple_doc_list(self, prompt_model):
if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
pytest.skip("No API key found for OpenAI, skipping test")

node = PromptNode(prompt_model, default_prompt_template="sentiment-analysis", output_variable="out")

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run_batch(
queries=None,
documents=[
[Document("Berlin is an amazing city."), Document("Paris is an amazing city.")],
[Document("I am not feeling well.")],
],
)
assert isinstance(result["out"], list)
assert isinstance(result["out"][0], list)
assert isinstance(result["out"][0][0], str)
assert all("positive" in x.casefold() for x in result["out"][0])
assert "negative" in result["out"][1][0].casefold()

@pytest.mark.integration
@pytest.mark.parametrize("prompt_model", ["hf", "openai"], indirect=True)
def test_simple_pipeline_batch_query_multiple_doc_list(self, prompt_model):
if prompt_model.api_key is not None and not is_openai_api_key_set(prompt_model.api_key):
pytest.skip("No API key found for OpenAI, skipping test")

prompt_template = PromptTemplate(
name="question-answering-new",
prompt_text="Given the context please answer the question. Context: $documents; Question: $query; Answer:",
prompt_params=["documents", "query"],
)
node = PromptNode(prompt_model, default_prompt_template=prompt_template)

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run_batch(
queries=["Who lives in Berlin?"],
documents=[
[Document("My name is Carla and I live in Berlin"), Document("My name is James and I live in London")],
[Document("My name is Christelle and I live in Paris")],
],
debug=True,
)
assert isinstance(result["results"], list)
assert isinstance(result["results"][0], list)
assert isinstance(result["results"][0][0], str)
# TODO Finish


def test_HFLocalInvocationLayer_supports():
assert HFLocalInvocationLayer.supports("philschmid/flan-t5-base-samsum")
assert HFLocalInvocationLayer.supports("bigscience/T0_3B")
