From 359c5f929553bbdc9cfdb7ea00ab8a7140ec2755 Mon Sep 17 00:00:00 2001
From: Jason Liu
Date: Sat, 25 Nov 2023 13:56:52 -0500
Subject: [PATCH] Add multiple modalities: tools, functions, json_mode (#218)

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
---
 docs/concepts/patching.md    |  70 +++++++++++++++
 instructor/function_calls.py |  48 +++++++++--
 instructor/patch.py          | 160 ++++++++++++++++++++++++-----------
 mkdocs.yml                   |   1 +
 tests/openai/test_modes.py   |  60 +++++++++++++
 tests/openai/test_patch.py   |  99 ++++++++++------------
 tests/test_validators.py     |  31 +++++--
 7 files changed, 358 insertions(+), 111 deletions(-)
 create mode 100644 docs/concepts/patching.md
 create mode 100644 tests/openai/test_modes.py

diff --git a/docs/concepts/patching.md b/docs/concepts/patching.md
new file mode 100644
index 000000000..16fb25b1b
--- /dev/null
+++ b/docs/concepts/patching.md
@@ -0,0 +1,70 @@
+# Patching
+
+Instructor patches the OpenAI client with three new, backwards-compatible keyword arguments, so you can keep using the client as usual while getting structured outputs back.
+
+- `response_model`: Defines the response type for `chat.completions.create`.
+- `max_retries`: Sets the number of retry attempts when a `chat.completions.create` response fails validation.
+- `validation_context`: Provides extra context to the validation process.
+
+There are three methods for structured output:
+
+1. **Function Calling**: The primary method. Use this for stability and testing.
+2. **Tool Calling**: Useful in specific scenarios; reasking on validation errors is not yet supported for OpenAI's tool calling API.
+3. **JSON Mode**: Adheres more closely to JSON output but is more prone to validation errors. Suitable for clients that do not support function calling.
+
+## Function Calling
+
+```python
+from openai import OpenAI
+import instructor
+
+client = instructor.patch(OpenAI())
+```
+
+## Tool Calling
+
+```python
+import instructor
+from instructor import Mode
+from openai import OpenAI
+client = instructor.patch(OpenAI(), mode=Mode.TOOLS)
+```
+
+## JSON Mode
+
+```python
+import instructor
+from instructor import Mode
+from openai import OpenAI
+
+client = instructor.patch(OpenAI(), mode=Mode.JSON)
+```
+
+### Schema Integration
+
+In JSON Mode, the schema is part of the system message:
+
+```python
+import instructor
+from openai import OpenAI
+
+client = instructor.patch(OpenAI())
+
+response = client.chat.completions.create(
+    model="gpt-3.5-turbo-1106",
+    response_format={"type": "json_object"},
+    messages=[
+        {
+            "role": "system",
+            "content": f"Match your response to this json_schema: \n{UserExtract.model_json_schema()['properties']}",
+        },
+        {
+            "role": "user",
+            "content": "Extract jason is 25 years old",
+        },
+    ],
+)
+user = UserExtract.from_response(response, mode=Mode.JSON)
+assert user.name.lower() == "jason"
+assert user.age == 25
+```
diff --git a/instructor/function_calls.py b/instructor/function_calls.py
index 5f82a38e6..a14482db4 100644
--- a/instructor/function_calls.py
+++ b/instructor/function_calls.py
@@ -1,9 +1,20 @@
+from calendar import c
 import json
 from docstring_parser import parse
 from functools import wraps
 from typing import Any, Callable
 from pydantic import BaseModel, create_model, validate_arguments
+import enum
+
+
+class Mode(enum.Enum):
+    """The mode to use for patching the client"""
+
+    FUNCTIONS: str = "function_call"
+    TOOLS: str = "tool_call"
+    JSON: str = "json_mode"
+

 class openai_function:
     """
@@ -176,9 +187,9 @@ def openai_schema(cls):
     def from_response(
         cls,
         completion,
-        throw_error: bool =
True, validation_context=None, strict: bool = None, + mode: Mode = Mode.FUNCTIONS, ): """Execute the function from the response of an openai chat completion @@ -193,11 +204,36 @@ def from_response( """ message = completion.choices[0].message - return cls.model_validate_json( - message.function_call.arguments, - context=validation_context, - strict=strict, - ) + if mode == Mode.FUNCTIONS: + assert ( + message.function_call.name == cls.openai_schema["name"] + ), "Function name does not match" + return cls.model_validate_json( + message.function_call.arguments, + context=validation_context, + strict=strict, + ) + elif mode == Mode.TOOLS: + assert ( + len(message.tool_calls) == 1 + ), "Instructor does not support multiple tool calls, use List[Model] instead." + tool_call = message.tool_calls[0] + assert ( + tool_call.function.name == cls.openai_schema["name"] + ), "Tool name does not match" + return cls.model_validate_json( + tool_call.function.arguments, + context=validation_context, + strict=strict, + ) + elif mode == Mode.JSON: + return cls.model_validate_json( + message.content, + context=validation_context, + strict=strict, + ) + else: + raise ValueError(f"Invalid patch mode: {mode}") def openai_schema(cls) -> OpenAISchema: diff --git a/instructor/patch.py b/instructor/patch.py index 776a35ab4..1830a8a0b 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -4,10 +4,12 @@ from typing import Callable, Optional, Type, Union from openai import AsyncOpenAI, OpenAI -from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat import ChatCompletion from pydantic import BaseModel, ValidationError -from .function_calls import OpenAISchema, openai_schema +from .function_calls import OpenAISchema, openai_schema, Mode + +import warnings OVERRIDE_DOCS = """ Creates a new chat completion for the provided messages and parameters. @@ -29,16 +31,68 @@ """ -def handle_response_model(response_model: Type[BaseModel], kwargs): +def dump_message(message) -> dict: + """Dumps a message to a dict, to be returned to the OpenAI API. + Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests + if it isn't used. 
+ """ + dumped_message = message.model_dump() + if not dumped_message.get("tool_calls"): + del dumped_message["tool_calls"] + return {k: v for k, v in dumped_message.items() if v} + + +def handle_response_model( + *, + response_model: Type[BaseModel], + kwargs, + mode: Mode = Mode.FUNCTIONS, +): new_kwargs = kwargs.copy() if response_model is not None: if not issubclass(response_model, OpenAISchema): response_model = openai_schema(response_model) # type: ignore - new_kwargs["functions"] = [response_model.openai_schema] # type: ignore - new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore + + if mode == Mode.FUNCTIONS: + new_kwargs["functions"] = [response_model.openai_schema] # type: ignore + new_kwargs["function_call"] = { + "name": response_model.openai_schema["name"] + } # type: ignore + elif mode == Mode.TOOLS: + new_kwargs["tools"] = [ + { + "type": "function", + "function": response_model.openai_schema, + } + ] + new_kwargs["tool_choice"] = { + "type": "function", + "function": {"name": response_model.openai_schema["name"]}, + } + elif mode == Mode.JSON: + new_kwargs["response_format"] = {"type": "json_object"} + + # check that the first message is a system message + # if it is not, add a system message to the beginning + message = f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}" + + if new_kwargs["messages"][0]["role"] != "system": + new_kwargs["messages"].insert( + 0, + { + "role": "system", + "content": message, + }, + ) + + # if the first message is a system append the schema to the end + if new_kwargs["messages"][0]["role"] == "system": + new_kwargs["messages"][0]["content"] += f"\n\n{message}" + else: + raise ValueError(f"Invalid patch mode: {mode}") if new_kwargs.get("stream", False) and response_model is not None: - import warnings + raise NotImplementedError("stream=True is not supported when using response_model parameter") warnings.warn( "stream=True is not supported when using response_model parameter" @@ -48,7 +102,12 @@ def handle_response_model(response_model: Type[BaseModel], kwargs): def process_response( - response, response_model, validation_context: dict = None, strict=None + response, + *, + response_model: Type[BaseModel], + validation_context: dict = None, + strict=None, + mode: Mode = Mode.FUNCTIONS, ): # type: ignore """Processes a OpenAI response with the response model, if available It can use `validation_context` and `strict` to validate the response @@ -62,25 +121,13 @@ def process_response( """ if response_model is not None: model = response_model.from_response( - response, validation_context=validation_context, strict=strict + response, validation_context=validation_context, strict=strict, mode=mode ) model._raw_response = response return model return response -def dump_message(message: ChatCompletionMessage) -> dict: - """Dumps a message to a dict, to be returned to the OpenAI API. - - Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests - if it isn't used. 
- """ - dumped_message = message.model_dump() - if not dumped_message.get("tool_calls"): - del dumped_message["tool_calls"] - return dumped_message - - async def retry_async( func, response_model, @@ -89,22 +136,21 @@ async def retry_async( kwargs, max_retries, strict: Optional[bool] = None, + mode: Mode = Mode.FUNCTIONS, ): retries = 0 while retries <= max_retries: try: response: ChatCompletion = await func(*args, **kwargs) - return ( - process_response( - response, - response_model, - validation_context, - strict=strict, - ), - None, + return process_response( + response, + response_model=response_model, + validation_context=validation_context, + strict=strict, + mode=mode, ) except (ValidationError, JSONDecodeError) as e: - kwargs["messages"].append(response.choices[0].message) # type: ignore + kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore kwargs["messages"].append( { "role": "user", @@ -124,20 +170,22 @@ def retry_sync( kwargs, max_retries, strict: Optional[bool] = None, + mode: Mode = Mode.FUNCTIONS, ): retries = 0 while retries <= max_retries: # Excepts ValidationError, and JSONDecodeError try: response = func(*args, **kwargs) - return ( - process_response( - response, response_model, validation_context, strict=strict - ), - None, + return process_response( + response, + response_model=response_model, + validation_context=validation_context, + strict=strict, + mode=mode, ) except (ValidationError, JSONDecodeError) as e: - kwargs["messages"].append(dump_message(response.choices[0].message)) + kwargs["messages"].append(response.choices[0].message) kwargs["messages"].append( { "role": "user", @@ -156,7 +204,9 @@ def is_async(func: Callable) -> bool: ) -def wrap_chatcompletion(func: Callable) -> Callable: +def wrap_chatcompletion( + func: Callable, mode: Mode = Mode.FUNCTIONS +) -> Callable: func_is_async = is_async(func) @wraps(func) @@ -167,17 +217,22 @@ async def new_chatcompletion_async( *args, **kwargs, ): - response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore - response, error = await retry_async( + if mode == Mode.TOOLS: + max_retries = 0 + warnings.warn("max_retries is not supported when using tool calls") + + response_model, new_kwargs = handle_response_model( + response_model=response_model, kwargs=kwargs, mode=mode + ) # type: ignore + response = await retry_async( func=func, response_model=response_model, validation_context=validation_context, max_retries=max_retries, args=args, kwargs=new_kwargs, + mode=mode, ) # type: ignore - if error: - raise ValueError(error) return response @wraps(func) @@ -188,17 +243,22 @@ def new_chatcompletion_sync( *args, **kwargs, ): - response_model, new_kwargs = handle_response_model(response_model, kwargs) # type: ignore - response, error = retry_sync( + if mode == Mode.TOOLS: + max_retries = 0 + warnings.warn("max_retries is not supported when using tool calls") + + response_model, new_kwargs = handle_response_model( + response_model=response_model, kwargs=kwargs, mode=mode + ) # type: ignore + response = retry_sync( func=func, response_model=response_model, validation_context=validation_context, max_retries=max_retries, args=args, kwargs=new_kwargs, + mode=mode, ) # type: ignore - if error: - raise ValueError(error) return response wrapper_function = ( @@ -208,7 +268,9 @@ def new_chatcompletion_sync( return wrapper_function -def patch(client: Union[OpenAI, AsyncOpenAI]): +def patch( + client: Union[OpenAI, AsyncOpenAI], mode: Mode = Mode.FUNCTIONS +): """ Patch the 
`client.chat.completions.create` method @@ -220,11 +282,13 @@ def patch(client: Union[OpenAI, AsyncOpenAI]): - `strict` parameter to use strict json parsing """ - client.chat.completions.create = wrap_chatcompletion(client.chat.completions.create) + client.chat.completions.create = wrap_chatcompletion( + client.chat.completions.create, mode=mode + ) return client -def apatch(client: AsyncOpenAI): +def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS): """ No longer necessary, use `patch` instead. @@ -237,4 +301,4 @@ def apatch(client: AsyncOpenAI): - `validation_context` parameter to validate the response using the pydantic model - `strict` parameter to use strict json parsing """ - return patch(client) + return patch(client, mode=mode) diff --git a/mkdocs.yml b/mkdocs.yml index c72cf8baa..14cd691d9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -131,6 +131,7 @@ nav: - Models: 'concepts/models.md' - Fields: 'concepts/fields.md' - Types: 'concepts/types.md' + - Patching: 'concepts/patching.md' - Streaming: "concepts/lists.md" - Union: 'concepts/union.md' - Alias: 'concepts/alias.md' diff --git a/tests/openai/test_modes.py b/tests/openai/test_modes.py new file mode 100644 index 000000000..ff92c0460 --- /dev/null +++ b/tests/openai/test_modes.py @@ -0,0 +1,60 @@ +from instructor.function_calls import OpenAISchema, Mode +from openai import OpenAI + + +client = OpenAI() + + +class UserExtract(OpenAISchema): + name: str + age: int + + +def test_tool_call(): + response = client.chat.completions.create( + model="gpt-3.5-turbo-1106", + messages=[ + { + "role": "user", + "content": "Extract jason is 25 years old, mary is 30 years old", + }, + ], + tools=[ + { + "type": "function", + "function": UserExtract.openai_schema, + } + ], + tool_choice={ + "type": "function", + "function": {"name": UserExtract.openai_schema["name"]}, + }, + ) + response_message = response.choices[0].message + tool_calls = response_message.tool_calls + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "UserExtract" + assert tool_calls[0].function + user = UserExtract.from_response(response, mode=Mode.TOOLS) + assert user.name.lower() == "jason" + assert user.age == 25 + + +def test_json_mode(): + response = client.chat.completions.create( + model="gpt-3.5-turbo-1106", + response_format={"type": "json_object"}, + messages=[ + { + "role": "system", + "content": f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}", + }, + { + "role": "user", + "content": "Extract jason is 25 years old", + }, + ], + ) + user = UserExtract.from_response(response, mode=Mode.JSON) + assert user.name.lower() == "jason" + assert user.age == 25 diff --git a/tests/openai/test_patch.py b/tests/openai/test_patch.py index b9e07ae6d..f6a8e3150 100644 --- a/tests/openai/test_patch.py +++ b/tests/openai/test_patch.py @@ -1,30 +1,27 @@ +from pydantic import BaseModel, field_validator import pytest import instructor -from instructor import llm_validator -from typing_extensions import Annotated -from pydantic import field_validator, BaseModel, BeforeValidator, ValidationError from openai import OpenAI, AsyncOpenAI -client = instructor.patch(OpenAI()) +from instructor.function_calls import Mode + aclient = instructor.patch(AsyncOpenAI()) +client = instructor.patch(OpenAI()) class UserExtract(BaseModel): name: str age: int - @field_validator("name") - @classmethod - def validate_name(cls, v): - if v.upper() != v: - raise ValueError("Name should be 
uppercase") - return v - -def test_runmodel_validator(): +@pytest.mark.parametrize( + "mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS] +) +def test_runmodel(mode): + client = instructor.patch(OpenAI(), mode=mode) model = client.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-3.5-turbo-1106", response_model=UserExtract, max_retries=2, messages=[ @@ -32,16 +29,21 @@ def test_runmodel_validator(): ], ) assert isinstance(model, UserExtract), "Should be instance of UserExtract" - assert model.name == "JASON" + assert model.name.lower() == "jason" + assert model.age == 25 assert hasattr( model, "_raw_response" ), "The raw response should be available from OpenAI" +@pytest.mark.parametrize( + "mode", [Mode.FUNCTIONS, Mode.JSON, Mode.TOOLS] +) @pytest.mark.asyncio -async def test_runmodel_async_validator(): +async def test_runmodel_async(mode): + aclient = instructor.patch(AsyncOpenAI(), mode=mode) model = await aclient.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-3.5-turbo-1106", response_model=UserExtract, max_retries=2, messages=[ @@ -49,64 +51,57 @@ async def test_runmodel_async_validator(): ], ) assert isinstance(model, UserExtract), "Should be instance of UserExtract" - assert model.name == "JASON" + assert model.name.lower() == "jason" + assert model.age == 25 assert hasattr( model, "_raw_response" ), "The raw response should be available from OpenAI" -class UserExtractSimple(BaseModel): +class UserExtractValidated(BaseModel): name: str age: int + @field_validator("name") + @classmethod + def validate_name(cls, v): + if v.upper() != v: + raise ValueError("Name should be uppercase") + return v -@pytest.mark.asyncio -async def test_async_runmodel(): - model = await aclient.chat.completions.create( - model="gpt-3.5-turbo", - response_model=UserExtractSimple, + +@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON]) +def test_runmodel_validator(mode): + client = instructor.patch(OpenAI(), mode=mode) + model = client.chat.completions.create( + model="gpt-3.5-turbo-1106", + response_model=UserExtractValidated, + max_retries=2, messages=[ {"role": "user", "content": "Extract jason is 25 years old"}, ], ) - assert isinstance( - model, UserExtractSimple - ), "Should be instance of UserExtractSimple" - assert model.name.lower() == "jason" + assert isinstance(model, UserExtractValidated), "Should be instance of UserExtract" + assert model.name == "JASON" assert hasattr( model, "_raw_response" ), "The raw response should be available from OpenAI" -def test_runmodel(): - model = client.chat.completions.create( - model="gpt-3.5-turbo", - response_model=UserExtractSimple, +@pytest.mark.parametrize("mode", [Mode.FUNCTIONS, Mode.JSON]) +@pytest.mark.asyncio +async def test_runmodel_async_validator(mode): + aclient = instructor.patch(AsyncOpenAI(), mode=mode) + model = await aclient.chat.completions.create( + model="gpt-3.5-turbo-1106", + response_model=UserExtractValidated, + max_retries=2, messages=[ {"role": "user", "content": "Extract jason is 25 years old"}, ], ) - assert isinstance( - model, UserExtractSimple - ), "Should be instance of UserExtractSimple" - assert model.name.lower() == "jason" + assert isinstance(model, UserExtractValidated), "Should be instance of UserExtract" + assert model.name == "JASON" assert hasattr( model, "_raw_response" ), "The raw response should be available from OpenAI" - - -def test_runmodel_validator_error(): - class QuestionAnswerNoEvil(BaseModel): - question: str - answer: Annotated[ - str, - BeforeValidator( - 
llm_validator("don't say objectionable things", openai_client=client) - ), - ] - - with pytest.raises(ValidationError): - QuestionAnswerNoEvil( - question="What is the meaning of life?", - answer="The meaning of life is to be evil and steal", - ) diff --git a/tests/test_validators.py b/tests/test_validators.py index a8302318b..da1c152b8 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,17 +1,38 @@ import pytest -import instructor +import instructor from typing_extensions import Annotated -from pydantic import BaseModel, AfterValidator, ValidationError +from pydantic import BaseModel, AfterValidator, BeforeValidator, ValidationError from openai import OpenAI +from instructor.dsl.validators import llm_validator + client = instructor.patch(OpenAI()) + def test_patch_completes_successfully(): class Response(BaseModel): - message: Annotated[str, AfterValidator(instructor.openai_moderation(client=client))] - + message: Annotated[ + str, AfterValidator(instructor.openai_moderation(client=client)) + ] with pytest.raises(ValidationError) as e: - Response(message="I want to make them suffer the consequences") \ No newline at end of file + Response(message="I want to make them suffer the consequences") + + +def test_runmodel_validator_error(): + class QuestionAnswerNoEvil(BaseModel): + question: str + answer: Annotated[ + str, + BeforeValidator( + llm_validator("don't say objectionable things", openai_client=client) + ), + ] + + with pytest.raises(ValidationError): + QuestionAnswerNoEvil( + question="What is the meaning of life?", + answer="The meaning of life is to be evil and steal", + )