Reasking logic on validations (#98)
* working cleaned up patch

* Reasking logic

* clean up

* remove

* clean up tests
jxnl authored Sep 8, 2023
1 parent cffbb04 commit 1cc45e3
Showing 5 changed files with 235 additions and 68 deletions.
82 changes: 82 additions & 0 deletions docs/reask.md
@@ -0,0 +1,82 @@
# Reasking When Validation Fails

Validators are a great tool for enforcing properties of your outputs. When you use the `patch()` method with the `openai` client, you can pass the `max_retries` parameter to set the number of times the client may reask. If validation fails, the client reattempts the API call up to that many times. It's another layer of defense against bad outputs of two forms:

1. Pydantic Validation Errors
2. JSON Decoding Errors
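
Both failure modes surface as ordinary Python exceptions, so the retry layer can catch them uniformly. Here is a minimal sketch of the JSON half; the `ValidationError` half comes from Pydantic and behaves the same way:

```python
import json
from json import JSONDecodeError

# A malformed function-call payload, as an LLM might emit it.
bad_payload = '{"name": "JASON", "age": 25'  # missing closing brace

try:
    json.loads(bad_payload)
except JSONDecodeError as e:
    # The exception message is what gets fed back to the model on reask.
    print(f"reask with error details: {e}")
```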

## Future Improvements

!!! note "Contributions Welcome"
    The current retry mechanism relies on a while loop. For a more robust solution, contributions to integrate the `tenacity` library are welcome.

## Example: Using Validators for Reasking

This example uses Pydantic's field validators in tandem with the `max_retries` parameter. If the `name` field fails validation, the `openai` client will reattempt the API call. Here we use a plain validator, but you can also use [LLMs for validation](validation.md).

### Step 1: Define the Response Model with Validators

```python
import instructor
from pydantic import BaseModel, field_validator

# Apply the patch to the OpenAI client
instructor.patch()


class UserDetails(BaseModel):
    name: str
    age: int

    @field_validator("name")
    @classmethod
    def validate_name(cls, v):
        if v.upper() != v:
            raise ValueError("Name must be in uppercase.")
        return v
```

Here, the `UserDetails` class includes a validator for the `name` attribute. The validator checks that the name is in uppercase and raises a `ValueError` otherwise.
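
The same uppercase check can be exercised on its own. This standalone sketch mirrors the validator body so you can see exactly which inputs pass and which raise:

```python
def validate_name(v: str) -> str:
    # Mirrors the field validator above: reject anything not fully uppercase.
    if v.upper() != v:
        raise ValueError("Name must be in uppercase.")
    return v


print(validate_name("JASON"))  # passes through unchanged
try:
    validate_name("jason")
except ValueError as e:
    print(e)  # Name must be in uppercase.
```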

### Step 2: Exception Handling and Reasking

When validation fails, several steps are taken:

1. The existing messages are retained for the new API request.
2. The previous function call's response is added back.
3. A user prompt is included to reask the model, with details on the error.

```python
try:
    ...
except (ValidationError, JSONDecodeError) as e:
    kwargs["messages"].append(dict(**response.choices[0].message))
    kwargs["messages"].append(
        {
            "role": "user",
            "content": f"Please correct the function call; errors encountered:\n{e}",
        }
    )
```

## Using the Client with Retries

Here, the `UserDetails` model is passed as the `response_model`, and `max_retries` is set to 2.

```python
model = openai.ChatCompletion.create(
    model="gpt-3.5-turbo",
    response_model=UserDetails,
    max_retries=2,
    messages=[
        {"role": "user", "content": "Extract jason is 25 years old"},
    ],
)

assert model.name == "JASON"
```

The `max_retries` parameter will trigger up to 2 reattempts if the `name` attribute fails the uppercase validation in `UserDetails`.
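
Under the hood, the reasking described above is a simple bounded loop. The following is an illustrative sketch with hypothetical `call_llm` and `parse` helpers, not the library's actual internals:

```python
def reask_loop(call_llm, parse, messages, max_retries=2):
    """Call the model; on a parse/validation failure, append the error as a
    user message and try again, allowing up to max_retries reattempts."""
    retries = 0
    while True:
        response = call_llm(messages)
        try:
            return parse(response)
        except ValueError as e:
            retries += 1
            if retries > max_retries:
                raise
            messages.append(
                {
                    "role": "user",
                    "content": f"Please correct the function call; errors encountered:\n{e}",
                }
            )
```

Because the error text is appended to the conversation before retrying, each reattempt sees exactly what went wrong on the previous call.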

## Takeaways

Instead of framing "self-critique" or "self-reflection" in AI as new concepts, we can view them as validation errors with clear error messages that the system can use to heal. This approach leverages existing error-handling practices rather than inventing new methodologies: we reduce the problem to code we already know how to write, and let Pydantic's powerful validation system do the heavy lifting.
6 changes: 3 additions & 3 deletions docs/validation.md
@@ -1,6 +1,6 @@
# Introduction to Validation in Pydantic and LLMs

Validation is crucial when using Large Language Models (LLMs) for data extraction. It ensures data integrity, enables reasking for better results, and allows for overwriting incorrect values. Pydantic offers versatile validation capabilities suitable for use with LLM outputs.
Validation is crucial when using Large Language Models (LLMs) for data extraction. It ensures data integrity, enables [reasking for better results](reask.md), and allows for overwriting incorrect values. Pydantic offers versatile validation capabilities suitable for use with LLM outputs.


!!! note "Pydantic Validation Docs"
@@ -14,14 +14,14 @@ Validation is crucial when using Large Language Models (LLMs) for data extractio
## Importance of LLM Validation

- **Data Integrity**: Enforces data quality standards.
- **Reasking**: Utilizes Pydantic's error messages to improve LLM outputs.
- **[Reasking](reask.md)**: Utilizes Pydantic's error messages to improve LLM outputs.
- **Overwriting**: Overwrites incorrect values during API calls.

## Code Examples

### Simple Validation with Pydantic

The example uses a custom validator function to enforce a rule on the name attribute. If a user fails to input a full name (first and last name separated by a space), Pydantic will raise a validation error. This is useful for pre-processing data generated or extracted by an LLM. In the future, we can use this error to reask the model when appropriate.
The example uses a custom validator function to enforce a rule on the name attribute. If a user fails to input a full name (first and last name separated by a space), Pydantic will raise a validation error. If you want the LLM to automatically fix the error, check out our [reasking docs](reask.md).

```python
from pydantic import BaseModel, ValidationError
178 changes: 115 additions & 63 deletions instructor/patch.py
@@ -1,72 +1,14 @@
from functools import wraps
from json import JSONDecodeError
from pydantic import ValidationError
import openai
import inspect
from typing import Callable, Optional, Type, Union
from typing import Callable, Type

from pydantic import BaseModel
from .function_calls import OpenAISchema, openai_schema


def wrap_chatcompletion(func: Callable) -> Callable:
    is_async = inspect.iscoroutinefunction(func)
    if is_async:

        @wraps(func)
        async def new_chatcompletion(
            *args,
            response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None,
            **kwargs
        ):  # type: ignore
            if response_model is not None:
                if not issubclass(response_model, OpenAISchema):
                    response_model = openai_schema(response_model)
                kwargs["functions"] = [response_model.openai_schema]
                kwargs["function_call"] = {"name": response_model.openai_schema["name"]}

            if kwargs.get("stream", False) and response_model is not None:
                import warnings

                warnings.warn(
                    "stream=True is not supported when using response_model parameter"
                )

            response = await func(*args, **kwargs)

            if response_model is not None:
                model = response_model.from_response(response)
                model._raw_response = response
                return model
            return response

    else:

        @wraps(func)
        def new_chatcompletion(
            *args,
            response_model: Optional[Union[Type[BaseModel], Type[OpenAISchema]]] = None,
            **kwargs
        ):
            if response_model is not None:
                if not issubclass(response_model, OpenAISchema):
                    response_model = openai_schema(response_model)
                kwargs["functions"] = [response_model.openai_schema]
                kwargs["function_call"] = {"name": response_model.openai_schema["name"]}

            if kwargs.get("stream", False) and response_model is not None:
                import warnings

                warnings.warn(
                    "stream=True is not supported when using response_model parameter"
                )

            response = func(*args, **kwargs)
            if response_model is not None:
                model = response_model.from_response(response)
                model._raw_response = response
                return model
            return response

new_chatcompletion.__doc__ = """
OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
See: https://platform.openai.com/docs/api-reference/chat-completions/create
@@ -82,8 +24,118 @@ def new_chatcompletion
Parameters:
response_model (Union[Type[BaseModel], Type[OpenAISchema]]): The response model to use for parsing the response from OpenAI's API, if available (default: None)
max_retries (int): The maximum number of retries to attempt if the response is not valid (default: 0)
"""
return new_chatcompletion


def handle_response_model(response_model: Type[BaseModel], kwargs):
    new_kwargs = kwargs.copy()
    if response_model is not None:
        if not issubclass(response_model, OpenAISchema):
            response_model = openai_schema(response_model)  # type: ignore
        new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
        new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}  # type: ignore

    if new_kwargs.get("stream", False) and response_model is not None:
        import warnings

        warnings.warn(
            "stream=True is not supported when using response_model parameter"
        )

    return response_model, new_kwargs


def process_response(response, response_model):  # type: ignore
    if response_model is not None:
        model = response_model.from_response(response)
        model._raw_response = response
        return model
    return response


async def retry_async(func, response_model, args, kwargs, max_retries):
    retries = 0
    while retries <= max_retries:
        try:
            response = await func(*args, **kwargs)
            return process_response(response, response_model), None
        except (ValidationError, JSONDecodeError) as e:
            kwargs["messages"].append(dict(**response.choices[0].message))
            kwargs["messages"].append(
                {
                    "role": "user",
                    "content": f"Recall the function correctly, exceptions found\n{e}",
                }
            )
            retries += 1
            if retries > max_retries:
                raise e


def retry_sync(func, response_model, args, kwargs, max_retries):
    retries = 0
    new_kwargs = kwargs.copy()
    while retries <= max_retries:
        # Excepts ValidationError, and JSONDecodeError
        try:
            response = func(*args, **kwargs)
            return process_response(response, response_model), None
        except (ValidationError, JSONDecodeError) as e:
            kwargs["messages"].append(dict(**response.choices[0].message))
            kwargs["messages"].append(
                {
                    "role": "user",
                    "content": f"Recall the function correctly, exceptions found\n{e}",
                }
            )
            retries += 1
            if retries > max_retries:
                raise e


def wrap_chatcompletion(func: Callable) -> Callable:
    is_async = inspect.iscoroutinefunction(func)

    @wraps(func)
    async def new_chatcompletion_async(response_model, *args, max_retries=0, **kwargs):
        response_model, new_kwargs = handle_response_model(response_model, kwargs)
        response, error = await retry_async(
            func=func,
            response_model=response_model,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
        )  # type: ignore
        if error:
            raise ValueError(error)
        return process_response(response, response_model)

    @wraps(func)
    def new_chatcompletion_sync(response_model, *args, max_retries=0, **kwargs):
        response_model, new_kwargs = handle_response_model(response_model, kwargs)
        response, error = retry_sync(
            func=func,
            response_model=response_model,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
        )  # type: ignore
        if error:
            raise ValueError(error)
        return response

    wrapper_function = new_chatcompletion_async if is_async else new_chatcompletion_sync
    wrapper_function.__doc__ = OVERRIDE_DOCS
    return wrapper_function


def process_response(response, response_model):
    if response_model is not None:
        model = response_model.from_response(response)
        model._raw_response = response
        return model
    return response


original_chatcompletion = openai.ChatCompletion.create
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -46,8 +46,9 @@ nav:
- Introduction:
  - Getting Started: 'index.md'
  - Prompt Engineering Tips: 'tips/index.md'
- Meta Functions:
- Helpers:
  - Validations (self critique): "validation.md"
  - Reasking via Validators: "reask.md"
  - Multiple Extractions: "multitask.md"
  - Handling Missing Content: "maybe.md"
- Philosophy: 'philosophy.md'
34 changes: 33 additions & 1 deletion tests/test_patch.py
@@ -4,7 +4,7 @@
from instructor import patch


@pytest.mark.skip(reason="Needs openai call")
@pytest.mark.skip("Not implemented")
def test_runmodel():
    patch()

@@ -24,3 +24,35 @@ class UserExtract(BaseModel):
    assert hasattr(
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"


@pytest.mark.skip("Not implemented")
def test_runmodel_validator():
    patch()

    from pydantic import field_validator

    class UserExtract(BaseModel):
        name: str
        age: int

        @field_validator("name")
        @classmethod
        def validate_name(cls, v):
            if v.upper() != v:
                raise ValueError("Name should be uppercase")
            return v

    model = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        response_model=UserExtract,
        max_retries=2,
        messages=[
            {"role": "user", "content": "Extract jason is 25 years old"},
        ],
    )
    assert isinstance(model, UserExtract), "Should be instance of UserExtract"
    assert model.name == "JASON"
    assert hasattr(
        model, "_raw_response"
    ), "The raw response should be available from OpenAI"
