Add multiple modalities: tools, functions, json_mode #218

Merged 6 commits on Nov 25, 2023

Changes from 1 commit
49 changes: 43 additions & 6 deletions instructor/function_calls.py
@@ -1,9 +1,20 @@
import json
from docstring_parser import parse
from functools import wraps
from typing import Any, Callable
from pydantic import BaseModel, create_model, validate_arguments

import enum


class PatchMode(enum.Enum):
    """The mode to use for patching the client"""

    FUNCTION_CALL: str = "function_call"
    TOOL_CALL: str = "tool_call"
    JSON_MODE: str = "json_mode"


class openai_function:
    """

@@ -176,9 +187,9 @@ def openai_schema(cls):
    def from_response(
        cls,
        completion,
-        throw_error: bool = True,
        validation_context=None,
        strict: bool = None,
+        mode: PatchMode = PatchMode.FUNCTION_CALL,
    ):
"""Execute the function from the response of an openai chat completion

Expand All @@ -193,11 +204,37 @@ def from_response(
"""
        message = completion.choices[0].message

-        return cls.model_validate_json(
-            message.function_call.arguments,
-            context=validation_context,
-            strict=strict,
-        )
+        match mode:
+            case PatchMode.FUNCTION_CALL:
+                assert (
+                    message.function_call.name == cls.openai_schema["name"]
+                ), "Function name does not match"
+                return cls.model_validate_json(
+                    message.function_call.arguments,
+                    context=validation_context,
+                    strict=strict,
+                )
+            case PatchMode.TOOL_CALL:
+                assert (
+                    len(message.tool_calls) == 1
+                ), "Instructor does not support multiple tool calls, use List[Model] instead."
+                tool_call = message.tool_calls[0]
+                assert (
+                    tool_call.function.name == cls.openai_schema["name"]
+                ), "Tool name does not match"
+                return cls.model_validate_json(
+                    tool_call.function.arguments,
+                    context=validation_context,
+                    strict=strict,
+                )
+            case PatchMode.JSON_MODE:
+                return cls.model_validate_json(
+                    message.content,
+                    context=validation_context,
+                    strict=strict,
+                )
+            case _:
+                raise ValueError(f"Invalid patch mode: {mode}")


def openai_schema(cls) -> OpenAISchema:
119 changes: 74 additions & 45 deletions instructor/patch.py
@@ -7,7 +7,12 @@
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from pydantic import BaseModel, ValidationError

-from .function_calls import OpenAISchema, openai_schema
+from .function_calls import OpenAISchema, openai_schema, PatchMode
+
+import logging
+
+logger = logging.getLogger(__name__)
Contributor comment:

The import of PatchMode and the logging module are correctly added as per the change summary. However, ensure that the logging module is used consistently throughout the file instead of print statements for logging purposes.
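For instance, the three print calls added to handle_response_model below could go through the module-level logger instead; a minimal sketch (the debug level is an assumption, not part of this diff):

```python
# Sketch: route the mode announcements through the module logger
# instead of print; the choice of debug level is an assumption.
logger.debug("Patching function call")
logger.debug("Patching tool call")
logger.debug("Patching json mode")
```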



OVERRIDE_DOCS = """
Creates a new chat completion for the provided messages and parameters.
@@ -29,13 +34,55 @@
"""


-def handle_response_model(response_model: Type[BaseModel], kwargs):
+def handle_response_model(
+    *, response_model: Type[BaseModel], kwargs, mode: PatchMode = PatchMode.FUNCTION_CALL
+):
    new_kwargs = kwargs.copy()
    if response_model is not None:
        if not issubclass(response_model, OpenAISchema):
            response_model = openai_schema(response_model)  # type: ignore
-        new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
-        new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}  # type: ignore

+        match mode:
+            case PatchMode.FUNCTION_CALL:
+                print("Patching function call")
+                new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
+                new_kwargs["function_call"] = {
+                    "name": response_model.openai_schema["name"]
+                }  # type: ignore
+            case PatchMode.TOOL_CALL:
+                print("Patching tool call")
+                new_kwargs["tools"] = [
+                    {
+                        "type": "function",
+                        "function": response_model.openai_schema,
+                    }
+                ]
+                new_kwargs["tool_choice"] = {
+                    "type": "function",
+                    "function": {"name": response_model.openai_schema["name"]},
+                }
+            case PatchMode.JSON_MODE:
+                print("Patching json mode")
+                new_kwargs["response_format"] = {"type": "json_object"}
+
+                # If the first message is not a system message, prepend a system
+                # message carrying the schema; otherwise append the schema to the
+                # existing system message.
+                message = f"Make sure that your response to any message matches the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"
+
+                if new_kwargs["messages"][0]["role"] != "system":
+                    new_kwargs["messages"].insert(
+                        0,
+                        {
+                            "role": "system",
+                            "content": message,
+                        },
+                    )
+                else:
+                    new_kwargs["messages"][0]["content"] += f"\n\n{message}"
+            case _:
+                raise ValueError(f"Invalid patch mode: {mode}")

if new_kwargs.get("stream", False) and response_model is not None:
import warnings
jxnl marked this conversation as resolved.
Show resolved Hide resolved
@@ -48,7 +95,7 @@ def handle_response_model(response_model: Type[BaseModel], kwargs):


def process_response(
-    response, response_model, validation_context: dict = None, strict=None
+    response, *, response_model: Type[BaseModel], validation_context: dict = None, strict=None, mode: PatchMode = PatchMode.FUNCTION_CALL
):  # type: ignore
"""Processes a OpenAI response with the response model, if available
It can use `validation_context` and `strict` to validate the response
@@ -62,25 +109,13 @@
"""
if response_model is not None:
model = response_model.from_response(
response, validation_context=validation_context, strict=strict
response, validation_context=validation_context, strict=strict, mode=mode
)
model._raw_response = response
return model
return response


-def dump_message(message: ChatCompletionMessage) -> dict:
-    """Dumps a message to a dict, to be returned to the OpenAI API.
-
-    Workaround for an issue with the OpenAI API, where the `tool_calls` field isn't allowed to be present in requests
-    if it isn't used.
-    """
-    dumped_message = message.model_dump()
-    if not dumped_message.get("tool_calls"):
-        del dumped_message["tool_calls"]
-    return dumped_message


async def retry_async(
    func,
    response_model,
@@ -89,20 +124,18 @@ async def retry_async(
    kwargs,
    max_retries,
    strict: Optional[bool] = None,
+    mode: PatchMode = PatchMode.FUNCTION_CALL,
):
    retries = 0
    while retries <= max_retries:
        try:
            response: ChatCompletion = await func(*args, **kwargs)
-            return (
-                process_response(
-                    response,
-                    response_model,
-                    validation_context,
-                    strict=strict,
-                ),
-                None,
-            )
+            return process_response(
+                response,
+                response_model=response_model,
+                validation_context=validation_context,
+                strict=strict,
+                mode=mode)
        except (ValidationError, JSONDecodeError) as e:
            kwargs["messages"].append(response.choices[0].message)  # type: ignore
Contributor comment:

The retry_async function attempts to append a message to kwargs["messages"] after catching an exception. However, this assumes that kwargs will always contain a messages key with a list value, which may not be the case. This could raise a KeyError or TypeError if messages is not present or not a list. The code should check for the existence and type of messages before appending.
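A minimal sketch of such a guard, slotted into the except branch (the names and the choice to raise are illustrative, not part of this diff):

```python
# Validate kwargs["messages"] before appending retry feedback.
# Raising ValueError here is an assumption; the diff appends unconditionally.
messages = kwargs.get("messages")
if not isinstance(messages, list):
    raise ValueError("retry requires kwargs['messages'] to be a list")
messages.append(response.choices[0].message)
```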

kwargs["messages"].append(
Expand All @@ -124,20 +157,18 @@ def retry_sync(
kwargs,
max_retries,
strict: Optional[bool] = None,
mode: PatchMode = PatchMode.FUNCTION_CALL,
):
    retries = 0
    while retries <= max_retries:
        # Catches ValidationError and JSONDecodeError
        try:
            response = func(*args, **kwargs)
-            return (
-                process_response(
-                    response, response_model, validation_context, strict=strict
-                ),
-                None,
-            )
+            return process_response(
+                response, response_model=response_model, validation_context=validation_context, strict=strict, mode=mode
+            )
        except (ValidationError, JSONDecodeError) as e:
-            kwargs["messages"].append(dump_message(response.choices[0].message))
+            kwargs["messages"].append(response.choices[0].message)
kwargs["messages"].append(
{
"role": "user",
Contributor comment (on lines +188 to 191):

The same issue as in retry_async applies to retry_sync. There should be a check for the existence and type of kwargs["messages"] before appending to it.

@@ -156,7 +187,7 @@ def is_async(func: Callable) -> bool:
    )


-def wrap_chatcompletion(func: Callable) -> Callable:
+def wrap_chatcompletion(func: Callable, mode: PatchMode = PatchMode.FUNCTION_CALL) -> Callable:
    func_is_async = is_async(func)

    @wraps(func)
@@ -167,17 +198,16 @@ async def new_chatcompletion_async(
        *args,
        **kwargs,
    ):
-        response_model, new_kwargs = handle_response_model(response_model, kwargs)  # type: ignore
-        response, error = await retry_async(
+        response_model, new_kwargs = handle_response_model(response_model=response_model, kwargs=kwargs, mode=mode)  # type: ignore
+        response = await retry_async(
            func=func,
            response_model=response_model,
            validation_context=validation_context,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
+            mode=mode,
        )  # type: ignore
-        if error:
-            raise ValueError(error)
        return response

    @wraps(func)
@@ -188,17 +218,16 @@ def new_chatcompletion_sync(
        *args,
        **kwargs,
    ):
-        response_model, new_kwargs = handle_response_model(response_model, kwargs)  # type: ignore
-        response, error = retry_sync(
+        response_model, new_kwargs = handle_response_model(response_model=response_model, kwargs=kwargs, mode=mode)  # type: ignore
+        response = retry_sync(
            func=func,
            response_model=response_model,
            validation_context=validation_context,
            max_retries=max_retries,
            args=args,
            kwargs=new_kwargs,
+            mode=mode,
        )  # type: ignore
-        if error:
-            raise ValueError(error)
        return response

    wrapper_function = (
@@ -208,7 +237,7 @@ def new_chatcompletion_sync(
    return wrapper_function
Contributor comment:

The wrap_chatcompletion function is missing a docstring. It's important to maintain consistency and provide documentation for all public functions, especially when they are part of a significant update to the codebase.
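A possible docstring, sketched here for illustration (the wording is an assumption, not part of this diff):

```python
def wrap_chatcompletion(func: Callable, mode: PatchMode = PatchMode.FUNCTION_CALL) -> Callable:
    """Wrap `chat.completions.create` so it accepts `response_model`,
    `validation_context`, and `max_retries`, and parses the completion
    into the given pydantic model using the selected `PatchMode`.
    """
```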



-def patch(client: Union[OpenAI, AsyncOpenAI]):
+def patch(client: Union[OpenAI, AsyncOpenAI], mode: PatchMode = PatchMode.FUNCTION_CALL):
"""
Patch the `client.chat.completions.create` method

Expand All @@ -220,11 +249,11 @@ def patch(client: Union[OpenAI, AsyncOpenAI]):
    - `strict` parameter to use strict json parsing
    """

-    client.chat.completions.create = wrap_chatcompletion(client.chat.completions.create)
+    client.chat.completions.create = wrap_chatcompletion(client.chat.completions.create, mode=mode)
    return client


-def apatch(client: AsyncOpenAI):
+def apatch(client: AsyncOpenAI, mode: PatchMode = PatchMode.FUNCTION_CALL):
"""
No longer necessary, use `patch` instead.

Expand All @@ -237,4 +266,4 @@ def apatch(client: AsyncOpenAI):
    - `validation_context` parameter to validate the response using the pydantic model
    - `strict` parameter to use strict json parsing
    """
-    return patch(client)
+    return patch(client, mode=mode)
60 changes: 60 additions & 0 deletions tests/openai/test_modes.py
@@ -0,0 +1,60 @@
from instructor.function_calls import OpenAISchema, PatchMode
from openai import OpenAI


client = OpenAI()


class UserExtract(OpenAISchema):
    name: str
    age: int


def test_tool_call():
    response = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        messages=[
            {
                "role": "user",
                "content": "Extract jason is 25 years old, mary is 30 years old",
            },
        ],
        tools=[
            {
                "type": "function",
                "function": UserExtract.openai_schema,
            }
        ],
        tool_choice={
            "type": "function",
            "function": {"name": UserExtract.openai_schema["name"]},
        },
    )
    response_message = response.choices[0].message
    tool_calls = response_message.tool_calls
    assert len(tool_calls) == 1
    assert tool_calls[0].function.name == "UserExtract"
    assert tool_calls[0].function
    user = UserExtract.from_response(response, mode=PatchMode.TOOL_CALL)
    assert user.name.lower() == "jason"
    assert user.age == 25


def test_json_mode():
    response = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        response_format={"type": "json_object"},
        messages=[
            {
                "role": "system",
                "content": f"Make sure that your response to any message matches the json_schema below, do not deviate at all: \n{UserExtract.model_json_schema()['properties']}",
            },
            {
                "role": "user",
                "content": "Extract jason is 25 years old",
            },
        ],
    )
    print(response.choices[0].message.content)
    user = UserExtract.from_response(response, mode=PatchMode.JSON_MODE)
    assert user.name.lower() == "jason"
    assert user.age == 25
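For symmetry, a FUNCTION_CALL-mode test could look like the sketch below (hypothetical, not part of this diff; it exercises the legacy functions/function_call request parameters that this mode preserves):

```python
def test_function_call():
    # Hypothetical companion test; mirrors test_tool_call using the
    # legacy functions/function_call request parameters.
    response = client.chat.completions.create(
        model="gpt-3.5-turbo-1106",
        functions=[UserExtract.openai_schema],
        function_call={"name": UserExtract.openai_schema["name"]},
        messages=[
            {"role": "user", "content": "Extract jason is 25 years old"},
        ],
    )
    user = UserExtract.from_response(response, mode=PatchMode.FUNCTION_CALL)
    assert user.name.lower() == "jason"
    assert user.age == 25
```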