Add multiple modalities: tools, functions, json_mode (#218)
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
jxnl and coderabbitai[bot] authored Nov 25, 2023
1 parent 7de55a9 commit 359c5f9
Showing 7 changed files with 358 additions and 111 deletions.
70 changes: 70 additions & 0 deletions docs/concepts/patching.md
@@ -0,0 +1,70 @@
# Patching

Instructor patches the client with three new keyword arguments while remaining backwards compatible: you can keep using the client as usual and opt into structured outputs where you need them (a usage sketch follows the list below).

- `response_model`: Defines the response type for `chat.completions.create`.
- `max_retries`: Determines retry attempts for failed `chat.completions.create` validations.
- `validation_context`: Provides extra context to the validation process.
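
As a quick illustration, here is a minimal sketch of the three keywords in use. The `UserDetail` model, the model name, and the prompt are illustrative, not part of this commit:

```python
import instructor
from openai import OpenAI
from pydantic import BaseModel

client = instructor.patch(OpenAI())


class UserDetail(BaseModel):
    name: str
    age: int


# The patched create call accepts the three extra keywords alongside the usual ones.
user = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_model=UserDetail,  # parse and validate the output into this model
    max_retries=2,  # reask the model up to twice on validation errors
    validation_context={"source": "docs example"},  # passed to validators as context
    messages=[{"role": "user", "content": "Extract: Jason is 25 years old"}],
)
assert isinstance(user, UserDetail)
```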

There are three methods for structured output:

1. **Function Calling**: The primary method. Use this for stability and testing.
2. **Tool Calling**: Useful in specific scenarios; reasking on validation errors is not yet supported when using OpenAI's tool calling API.
3. **JSON Mode**: Offers closer adherence to JSON output, but with more potential validation errors. Suitable for clients that do not support function calling.

## Function Calling

```python
from openai import OpenAI
import instructor

client = instructor.patch(OpenAI())  # function calling (Mode.FUNCTIONS) is the default mode
```

## Tool Calling

```python
import instructor
from instructor import Mode
from openai import OpenAI

client = instructor.patch(OpenAI(), mode=Mode.TOOLS)
```
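
The shape of the call does not change in tool mode. A hedged sketch of the same extraction through the tools-patched `client` from the snippet above, reusing the hypothetical `UserDetail` model; the response is expected to contain exactly one tool call, which is validated against the model:

```python
from pydantic import BaseModel


class UserDetail(BaseModel):
    name: str
    age: int


# Same call as in function-calling mode; instructor sends the schema via the
# tools API and parses the single tool call in the response.
user = client.chat.completions.create(
    model="gpt-3.5-turbo-1106",
    response_model=UserDetail,
    messages=[{"role": "user", "content": "Extract: Jason is 25 years old"}],
)
```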

## JSON Mode

```python
import instructor
from instructor import Mode
from openai import OpenAI

client = instructor.patch(OpenAI(), mode=Mode.JSON)
```

### Schema Integration

In JSON Mode, the schema is part of the system message:

```python
import instructor
from instructor import Mode, OpenAISchema
from openai import OpenAI


# UserExtract subclasses OpenAISchema so that `from_response` is available.
class UserExtract(OpenAISchema):
    name: str
    age: int


client = instructor.patch(OpenAI())

response = client.chat.completions.create(
    model="gpt-3.5-turbo-1106",
    response_format={"type": "json_object"},
    messages=[
        {
            "role": "system",
            "content": f"Match your response to this json_schema: \n{UserExtract.model_json_schema()['properties']}",
        },
        {
            "role": "user",
            "content": "Extract jason is 25 years old",
        },
    ],
)
user = UserExtract.from_response(response, mode=Mode.JSON)
assert user.name.lower() == "jason"
assert user.age == 25
```
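
For comparison, the same extraction can go through the patched client directly, without hand-writing the system message. A minimal sketch, assuming the JSON-mode patch builds the schema-bearing system message from `response_model` for you:

```python
import instructor
from instructor import Mode
from openai import OpenAI
from pydantic import BaseModel

client = instructor.patch(OpenAI(), mode=Mode.JSON)


class UserExtract(BaseModel):
    name: str
    age: int


# No manual schema injection: the patch derives it from response_model.
user = client.chat.completions.create(
    model="gpt-3.5-turbo-1106",
    response_model=UserExtract,
    messages=[{"role": "user", "content": "Extract jason is 25 years old"}],
)
assert user.name.lower() == "jason"
assert user.age == 25
```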
48 changes: 42 additions & 6 deletions instructor/function_calls.py
@@ -1,9 +1,20 @@
import json
from docstring_parser import parse
from functools import wraps
from typing import Any, Callable
from pydantic import BaseModel, create_model, validate_arguments

import enum


class Mode(enum.Enum):
    """The mode to use for patching the client"""

    FUNCTIONS: str = "function_call"
    TOOLS: str = "tool_call"
    JSON: str = "json_mode"


class openai_function:
"""
@@ -176,9 +187,9 @@ def openai_schema(cls):
def from_response(
    cls,
    completion,
    throw_error: bool = True,
    validation_context=None,
    strict: bool = None,
    mode: Mode = Mode.FUNCTIONS,
):
"""Execute the function from the response of an openai chat completion
@@ -193,11 +204,36 @@
"""
message = completion.choices[0].message

return cls.model_validate_json(
    message.function_call.arguments,
    context=validation_context,
    strict=strict,
)
if mode == Mode.FUNCTIONS:
    assert (
        message.function_call.name == cls.openai_schema["name"]
    ), "Function name does not match"
    return cls.model_validate_json(
        message.function_call.arguments,
        context=validation_context,
        strict=strict,
    )
elif mode == Mode.TOOLS:
    assert (
        len(message.tool_calls) == 1
    ), "Instructor does not support multiple tool calls, use List[Model] instead."
    tool_call = message.tool_calls[0]
    assert (
        tool_call.function.name == cls.openai_schema["name"]
    ), "Tool name does not match"
    return cls.model_validate_json(
        tool_call.function.arguments,
        context=validation_context,
        strict=strict,
    )
elif mode == Mode.JSON:
    return cls.model_validate_json(
        message.content,
        context=validation_context,
        strict=strict,
    )
else:
    raise ValueError(f"Invalid patch mode: {mode}")


def openai_schema(cls) -> OpenAISchema:
