From 6032fb05a194992e91c6e78a0a735f57f4b9cd60 Mon Sep 17 00:00:00 2001
From: Rowan Skewes
Date: Thu, 24 Oct 2024 18:15:20 +1100
Subject: [PATCH 1/2] Add test for unbounded recursion on parse errors

---
 tests/unit/test_prompt.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/unit/test_prompt.py b/tests/unit/test_prompt.py
index b9b1d4079..10db3b432 100644
--- a/tests/unit/test_prompt.py
+++ b/tests/unit/test_prompt.py
@@ -3,6 +3,7 @@
 import pytest
 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.prompt_values import StringPromptValue
+from pydantic import BaseModel
 
 from ragas.llms.base import BaseRagasLLM
 from ragas.prompt import StringIO, StringPrompt
@@ -203,3 +204,25 @@ def test_prompt_class_attributes():
     p.examples = []
     assert p.instruction != p_another_instance.instruction
     assert p.examples != p_another_instance.examples
+
+
+@pytest.mark.asyncio
+async def test_prompt_parse_retry():
+    from ragas.prompt import PydanticPrompt, StringIO
+    from ragas.exceptions import RagasOutputParserException
+
+    class OutputModel(BaseModel):
+        example: str
+
+    class Prompt(PydanticPrompt[StringIO, OutputModel]):
+        instruction = ""
+        input_model = StringIO
+        output_model = OutputModel
+
+    echo_llm = EchoLLM(run_config=RunConfig())
+    prompt = Prompt()
+    with pytest.raises(RagasOutputParserException):
+        await prompt.generate(
+            data=StringIO(text="this prompt will be echoed back as invalid JSON"),
+            llm=echo_llm,
+        )

From d3798f2e36b53bad97909bb4f95ef7d06a1ebb1a Mon Sep 17 00:00:00 2001
From: Rowan Skewes
Date: Thu, 24 Oct 2024 17:33:30 +1100
Subject: [PATCH 2/2] Fix: Limit number of retries for parse failures

When parsing of an LLM response fails, the invalid output is sent back
to the LLM to be fixed. This change threads the number of remaining
retries through that call, preventing unbounded recursion.
---
 src/ragas/exceptions.py             |  4 ++--
 src/ragas/prompt/pydantic_prompt.py | 24 +++++++++++++-----------
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/ragas/exceptions.py b/src/ragas/exceptions.py
index 2f8327035..122553639 100644
--- a/src/ragas/exceptions.py
+++ b/src/ragas/exceptions.py
@@ -26,9 +26,9 @@ class RagasOutputParserException(RagasException):
     Exception raised when the output parser fails to parse the output.
     """
 
-    def __init__(self, num_retries: int):
+    def __init__(self):
         msg = (
-            f"The output parser failed to parse the output after {num_retries} retries."
+            "The output parser failed to parse the output even after retries."
         )
         super().__init__(msg)
diff --git a/src/ragas/prompt/pydantic_prompt.py b/src/ragas/prompt/pydantic_prompt.py
index 3ff2c1469..1bbf75563 100644
--- a/src/ragas/prompt/pydantic_prompt.py
+++ b/src/ragas/prompt/pydantic_prompt.py
@@ -93,6 +93,7 @@ async def generate(
         temperature: t.Optional[float] = None,
         stop: t.Optional[t.List[str]] = None,
         callbacks: t.Optional[Callbacks] = None,
+        retries_left: int = 3,
     ) -> OutputModel:
         """
         Generate a single output using the provided language model and input data.
@@ -111,6 +112,8 @@ async def generate(
             A list of stop sequences to end generation.
         callbacks : Callbacks, optional
             Callback functions to be called during the generation process.
+        retries_left : int, optional
+            Number of retry attempts for an invalid LLM response.
 
         Returns
         -------
@@ -131,6 +134,7 @@ async def generate(
             temperature=temperature,
             stop=stop,
             callbacks=callbacks,
+            retries_left=retries_left,
         )
         return output_single[0]
@@ -142,6 +146,7 @@ async def generate_multiple(
         temperature: t.Optional[float] = None,
         stop: t.Optional[t.List[str]] = None,
         callbacks: t.Optional[Callbacks] = None,
+        retries_left: int = 3,
     ) -> t.List[OutputModel]:
         """
         Generate multiple outputs using the provided language model and input data.
@@ -160,6 +165,8 @@ async def generate_multiple(
             A list of stop sequences to end generation.
         callbacks : Callbacks, optional
             Callback functions to be called during the generation process.
+        retries_left : int, optional
+            Number of retry attempts for an invalid LLM response.
 
         Returns
         -------
@@ -198,7 +205,7 @@ async def generate_multiple(
                     prompt_value=prompt_value,
                     llm=llm,
                     callbacks=prompt_cb,
-                    max_retries=3,
+                    retries_left=retries_left,
                 )
                 processed_output = self.process_output(answer, data)  # type: ignore
                 output_models.append(processed_output)
@@ -390,14 +397,14 @@ async def parse_output_string(
         prompt_value: PromptValue,
         llm: BaseRagasLLM,
         callbacks: Callbacks,
-        max_retries: int = 1,
+        retries_left: int = 1,
     ):
         callbacks = callbacks or []
         try:
             jsonstr = extract_json(output_string)
             result = super().parse(jsonstr)
         except OutputParserException:
-            if max_retries != 0:
+            if retries_left != 0:
                 retry_rm, retry_cb = new_group(
                     name="fix_output_format",
                     inputs={"output_string": output_string},
@@ -410,17 +417,12 @@ async def parse_output_string(
                         prompt_value=prompt_value.to_string(),
                     ),
                     callbacks=retry_cb,
+                    retries_left=retries_left - 1,
                 )
                 retry_rm.on_chain_end({"fixed_output_string": fixed_output_string})
-                return await self.parse_output_string(
-                    output_string=fixed_output_string.text,
-                    prompt_value=prompt_value,
-                    llm=llm,
-                    max_retries=max_retries - 1,
-                    callbacks=callbacks,
-                )
+                result = fixed_output_string
             else:
-                raise RagasOutputParserException(num_retries=max_retries)
+                raise RagasOutputParserException()
         return result
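
Reviewer note: the sketch below shows, in isolation, the bounded-retry pattern that PATCH 2/2 threads through parse_output_string. It is not ragas code: parse_or_fix, OutputParserError, and the stub fixer are illustrative assumptions, json.loads stands in for the real output parser, and the call chain is simplified to direct recursion rather than going through the fix-output prompt.

# Minimal sketch of the bounded-retry pattern (assumed names, not ragas code).
import json


class OutputParserError(Exception):
    """Raised once all retries for fixing invalid output are used up."""


def parse_or_fix(output: str, fixer, retries_left: int = 3):
    """Parse `output`; on failure ask `fixer` to repair it, at most `retries_left` times."""
    try:
        return json.loads(output)
    except json.JSONDecodeError:
        if retries_left == 0:
            # Retry budget exhausted: raise instead of recursing forever.
            raise OutputParserError("failed to parse output after retries")
        # `fixer` stands in for the LLM round-trip; recurse with one fewer retry,
        # so recursion depth is always bounded by the initial budget.
        return parse_or_fix(fixer(output), fixer, retries_left=retries_left - 1)


if __name__ == "__main__":
    # A fixer that never yields valid JSON, like the EchoLLM used in the test:
    # the call terminates with an exception once the retries are spent.
    try:
        parse_or_fix("not json", fixer=lambda s: s)
    except OutputParserError as exc:
        print(exc)

With this shape the worst case is retries_left + 1 parse attempts, which the new test_prompt_parse_retry exercises indirectly by expecting RagasOutputParserException rather than unbounded recursion.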