From 1cc45e3faf5d59d4936df01616ba6876ed78a93c Mon Sep 17 00:00:00 2001 From: Jason Liu Date: Fri, 8 Sep 2023 00:58:29 -0500 Subject: [PATCH] Reasking logic on validations (#98) * working cleaned up patch * Reasking logic * clean up * remove * clean up tests --- docs/reask.md | 82 ++++++++++++++++++++ docs/validation.md | 6 +- instructor/patch.py | 178 ++++++++++++++++++++++++++++---------------- mkdocs.yml | 3 +- tests/test_patch.py | 34 ++++++++- 5 files changed, 235 insertions(+), 68 deletions(-) create mode 100644 docs/reask.md diff --git a/docs/reask.md b/docs/reask.md new file mode 100644 index 000000000..0a898b112 --- /dev/null +++ b/docs/reask.md @@ -0,0 +1,82 @@ +# Reasking When Validation Fails + +Validators are a great tool for ensuring some property of the outputs. When you use the `patch()` method with the `openai` client, you can use the `max_retries` parameter to set the number of times you can reask. This allows the client to reattempt the API call a specified number of times if validation fails. Its another layer of defense against bad outputs of two forms. + +1. Pydantic Validation Errors +2. JSON Decoding Errors + +## Future Improvements + +!!! notes "Contributions Welcome" + The current retry mechanism relies on a while loop. For a more robust solution, contributions to integrate the `tenacity` library are welcome. + +## Example: Using Validators for Reasking + +The example utilizes Pydantic's field validators in tandem with the `max_retries` parameter. In this example if the `name` field fails validation, the `openai` client will reattempt the API call. Here we use a plain validator, but we can also use [llms for validation](validation.md) + +### Step 1: Define the Response Model with Validators + +```python +import instructor +from pydantic import BaseModel, field_validator + +# Apply the patch to the OpenAI client +instructor.patch() + +class UserDetails(BaseModel): + name: str + age: int + + @field_validator("name") + @classmethod + def validate_name(cls, v): + if v.upper() != v: + raise ValueError("Name must be in uppercase.") + return v +``` + +Here, the `UserDetails` class includes a validator for the `name` attribute. The validator checks that the name is in uppercase and raises a `ValueError` otherwise. + +### Step 2: Exception Handling and Reasking + +When validation fails, several steps are taken: + +1. The existing messages are retained for the new API request. +2. The previous function call's response is added back. +3. A user prompt is included to reask the model, with details on the error. + +```python +try: + ... +except (ValidationError, JSONDecodeError) as e: + kwargs["messages"].append(dict(**response.choices[0].message)) + kwargs["messages"].append( + { + "role": "user", + "content": f"Please correct the function call; errors encountered:\n{e}", + } + ) +``` + +## Using the Client with Retries + +Here, the `UserDetails` model is passed as the `response_model`, and `max_retries` is set to 2. + +```python +model = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + response_model=UserDetails, + max_retries=2, + messages=[ + {"role": "user", "content": "Extract jason is 25 years old"}, + ], +) + +assert model.name == "JASON" +``` + +The `max_retries` parameter will trigger up to 2 reattempts if the `name` attribute fails the uppercase validation in `UserDetails`. + +## Takeaways + +Instead of framing "self-critique" or "self-reflection" in AI as new concepts, we can view them as validation errors with clear error messages that the systen can use to heal. This approach leverages existing programming practices for error handling, avoiding the need for new methodologies. We simplify the issue into code we already know how to write and leverage pydantic's powerful validation system to do so. \ No newline at end of file diff --git a/docs/validation.md b/docs/validation.md index 7b635a84c..8676c2616 100644 --- a/docs/validation.md +++ b/docs/validation.md @@ -1,6 +1,6 @@ # Introduction to Validation in Pydantic and LLMs -Validation is crucial when using Large Language Models (LLMs) for data extraction. It ensures data integrity, enables reasking for better results, and allows for overwriting incorrect values. Pydantic offers versatile validation capabilities suitable for use with LLM outputs. +Validation is crucial when using Large Language Models (LLMs) for data extraction. It ensures data integrity, enables [reasking for better results](reask.md), and allows for overwriting incorrect values. Pydantic offers versatile validation capabilities suitable for use with LLM outputs. !!! note "Pydantic Validation Docs" @@ -14,14 +14,14 @@ Validation is crucial when using Large Language Models (LLMs) for data extractio ## Importance of LLM Validation - **Data Integrity**: Enforces data quality standards. -- **Reasking**: Utilizes Pydantic's error messages to improve LLM outputs. +- **[Reasking](reask.md)**: Utilizes Pydantic's error messages to improve LLM outputs. - **Overwriting**: Overwrites incorrect values during API calls. ## Code Examples ### Simple Validation with Pydantic -The example uses a custom validator function to enforce a rule on the name attribute. If a user fails to input a full name (first and last name separated by a space), Pydantic will raise a validation error. This is useful for pre-processing data generated or extracted by an LLM. In the future, we can use this error to reask the model when appropriate. +The example uses a custom validator function to enforce a rule on the name attribute. If a user fails to input a full name (first and last name separated by a space), Pydantic will raise a validation error. If you want the LLM to automatically fix the error check out our [reasking docs.](reask.md) ```python from pydantic import BaseModel, ValidationError diff --git a/instructor/patch.py b/instructor/patch.py index 2a66fc47b..79f247c5b 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -1,72 +1,14 @@ from functools import wraps +from json import JSONDecodeError +from pydantic import ValidationError import openai import inspect -from typing import Callable, Optional, Type, Union +from typing import Callable, Type from pydantic import BaseModel from .function_calls import OpenAISchema, openai_schema - -def wrap_chatcompletion(func: Callable) -> Callable: - is_async = inspect.iscoroutinefunction(func) - if is_async: - - @wraps(func) - async def new_chatcompletion( - *args, - response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None, - **kwargs - ): # type: ignore - if response_model is not None: - if not issubclass(response_model, OpenAISchema): - response_model = openai_schema(response_model) - kwargs["functions"] = [response_model.openai_schema] - kwargs["function_call"] = {"name": response_model.openai_schema["name"]} - - if kwargs.get("stream", False) and response_model is not None: - import warnings - - warnings.warn( - "stream=True is not supported when using response_model parameter" - ) - - response = await func(*args, **kwargs) - - if response_model is not None: - model = response_model.from_response(response) - model._raw_response = response - return model - return response - - else: - - @wraps(func) - def new_chatcompletion( - *args, - response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None, - **kwargs - ): - if response_model is not None: - if not issubclass(response_model, OpenAISchema): - response_model = openai_schema(response_model) - kwargs["functions"] = [response_model.openai_schema] - kwargs["function_call"] = {"name": response_model.openai_schema["name"]} - - if kwargs.get("stream", False) and response_model is not None: - import warnings - - warnings.warn( - "stream=True is not supported when using response_model parameter" - ) - - response = func(*args, **kwargs) - if response_model is not None: - model = response_model.from_response(response) - model._raw_response = response - return model - return response - - new_chatcompletion.__doc__ = """ +OVERRIDE_DOCS = """ Creates a new chat completion for the provided messages and parameters. See: https://platform.openai.com/docs/api-reference/chat-completions/create @@ -82,8 +24,118 @@ def new_chatcompletion( Parameters: response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None) + max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0) """ - return new_chatcompletion + + +def handle_response_model(response_model: Type[BaseModel], kwargs): + new_kwargs = kwargs.copy() + if response_model is not None: + if not issubclass(response_model, OpenAISchema): + response_model = openai_schema(response_model) # type: ignore + new_kwargs["functions"] = [response_model.openai_schema] # type: ignore + new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore + + if new_kwargs.get("stream", False) and response_model is not None: + import warnings + + warnings.warn( + "stream=True is not supported when using response_model parameter" + ) + + return response_model, new_kwargs + + +def process_response(response, response_model): # type: ignore + if response_model is not None: + model = response_model.from_response(response) + model._raw_response = response + return model + return response + + +async def retry_async(func, response_model, args, kwargs, max_retries): + retries = 0 + while retries <= max_retries: + try: + response = await func(*args, **kwargs) + return process_response(response, response_model), None + except (ValidationError, JSONDecodeError) as e: + kwargs["messages"].append(dict(**response.choices[0].message)) + kwargs["messages"].append( + { + "role": "user", + "content": f"Recall the function correctly, exceptions found\n{e}", + } + ) + retries += 1 + if retries > max_retries: + raise e + + +def retry_sync(func, response_model, args, kwargs, max_retries): + retries = 0 + new_kwargs = kwargs.copy() + while retries <= max_retries: + # Excepts ValidationError, and JSONDecodeError + try: + response = func(*args, **kwargs) + return process_response(response, response_model), None + except (ValidationError, JSONDecodeError) as e: + kwargs["messages"].append(dict(**response.choices[0].message)) + kwargs["messages"].append( + { + "role": "user", + "content": f"Recall the function correctly, exceptions found\n{e}", + } + ) + retries += 1 + if retries > max_retries: + raise e + + +def wrap_chatcompletion(func: Callable) -> Callable: + is_async = inspect.iscoroutinefunction(func) + + @wraps(func) + async def new_chatcompletion_async(response_model, *args, max_retries=0, **kwargs): + response_model, new_kwargs = handle_response_model(response_model, kwargs) + response, error = await retry_async( + func=func, + response_model=response_model, + max_retries=max_retries, + args=args, + kwargs=new_kwargs, + ) # type: ignore + if error: + raise ValueError(error) + return process_response(response, response_model) + + @wraps(func) + def new_chatcompletion_sync(response_model, *args, max_retries=0, **kwargs): + response_model, new_kwargs = handle_response_model(response_model, kwargs) + response, error = retry_sync( + func=func, + response_model=response_model, + max_retries=max_retries, + args=args, + kwargs=new_kwargs, + ) # type: ignore + if error: + raise ValueError(error) + return response + + wrapper_function = new_chatcompletion_async if is_async else new_chatcompletion_sync + wrapper_function.__doc__ = OVERRIDE_DOCS + return wrapper_function + + +def process_response(response, response_model): + if response_model is not None: + model = response_model.from_response(response) + model._raw_response = response + return model + return response original_chatcompletion = openai.ChatCompletion.create diff --git a/mkdocs.yml b/mkdocs.yml index 764c0745f..848d1745f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,8 +46,9 @@ nav: - Introduction: - Getting Started: 'index.md' - Prompt Engineering Tips: 'tips/index.md' - - Meta Functions: + - Helpers: - Validations (self critique): "validation.md" + - Reasking via Validators: "reask.md" - Multiple Extractions: "multitask.md" - Handling Missing Content: "maybe.md" - Philosophy: 'philosophy.md' diff --git a/tests/test_patch.py b/tests/test_patch.py index b3d157504..9769a875f 100644 --- a/tests/test_patch.py +++ b/tests/test_patch.py @@ -4,7 +4,7 @@ from instructor import patch -@pytest.mark.skip(reason="Needs openai call") +@pytest.mark.skip("Not implemented") def test_runmodel(): patch() @@ -24,3 +24,35 @@ class UserExtract(BaseModel): assert hasattr( model, "_raw_response" ), "The raw response should be available from OpenAI" + + +@pytest.mark.skip("Not implemented") +def test_runmodel_validator(): + patch() + + from pydantic import field_validator + + class UserExtract(BaseModel): + name: str + age: int + + @field_validator("name") + @classmethod + def validate_name(cls, v): + if v.upper() != v: + raise ValueError("Name should be uppercase") + return v + + model = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + response_model=UserExtract, + max_retries=2, + messages=[ + {"role": "user", "content": "Extract jason is 25 years old"}, + ], + ) + assert isinstance(model, UserExtract), "Should be instance of UserExtract" + assert model.name == "JASON" + assert hasattr( + model, "_raw_response" + ), "The raw response should be available from OpenAI"