Support Streaming MultiTask with response_model #221

Merged — 13 commits, Nov 26, 2023
docs/concepts/lists.md — 42 changes: 24 additions & 18 deletions
@@ -1,4 +1,4 @@
-# Streaming and MultiTask
+# Multi-task and Streaming

A common use case of structured extraction is defining a single schema class and then defining a second schema that holds a list of the first, so that multiple instances can be extracted at once.

@@ -13,40 +13,44 @@

```python
class Users(BaseModel):
    users: List[User]
```

-Defining a task and creating a list of classes is a common enough pattern that we define a helper function `MultiTask` It procides a function to dynamically create a new class that:
+Defining a task and creating a list of classes is a common enough pattern that we make this convenient by using `Iterable[T]`. This lets us dynamically create a new class that:

-1. Dynamic docstrings and class name baed on the task
-2. Helper method to support streaming by collectin function_call tokens until a object back out.
+1. Has dynamic docstrings and a class name based on the task
+2. Supports streaming by collecting tokens until a complete task is received

-## Extracting Tasks using MultiTask
+## Extracting Tasks using Iterable

-By using multitask you get a very convient class with prompts and names automatically defined. You get `from_response` just like any other `BaseModel` you're able to extract the list of objects data you want with `MultTask.tasks`.
+By using `Iterable` you get a very convenient class with prompts and names automatically defined:

```python
import instructor
from openai import OpenAI
from typing import Iterable
from pydantic import BaseModel

-client = instructor.patch(OpenAI())
+client = instructor.patch(OpenAI(), mode=instructor.function_calls.Mode.JSON)

class User(BaseModel):
    name: str
    age: int

-MultiUser = instructor.MultiTask(User)
+Users = Iterable[User]

-completion = client.chat.completions.create(
-    model="gpt-4-0613",
+users = client.chat.completions.create(
+    model="gpt-3.5-turbo-1106",
    temperature=0.1,
+    response_model=Users,
    stream=False,
-    functions=[MultiUser.openai_schema],
-    function_call={"name": MultiUser.openai_schema["name"]},
    messages=[
        {
            "role": "user",
-            "content": f"Consider the data below: Jason is 10 and John is 30",
+            "content": "Consider this data: Jason is 10 and John is 30.\
+            Correctly segment it into entities.\
+            Make sure the JSON is correct.",
        },
    ],
)
+users.model_dump_json()
```

```json
{
  "tasks": [
    {"name": "Jason", "age": 10},
    {"name": "John", "age": 30}
  ]
}
```

@@ -60,18 +64,20 @@

## Streaming Tasks

-Since a `MultiTask(T)` is well contrained to `tasks: List[T]` we can make assuptions on how tokens are used and provide a helper method that allows you generate tasks as the the tokens are streamed in
+We can also generate tasks as the tokens are streamed in by defining an `Iterable[T]` type.

Let's look at an example in action with the same class:

```python hl_lines="6 26"
-MultiUser = instructor.MultiTask(User)
+from typing import Iterable
+
+Users = Iterable[User]

-completion = client.chat.completions.create(
+users = client.chat.completions.create(
    model="gpt-4",
    temperature=0.1,
    stream=True,
-    response_model=MultiUser,
+    response_model=Users,
    messages=[
        {
            "role": "system",
```

Reviewer comment (Contributor) on lines 60 to 79 — note: this review was outside the patches, so it was mapped to the patch with the greatest overlap (original lines 67 to 100): The streaming example uses an undefined variable `input`, which may lead to confusion. It should be defined, or the example should be clarified to indicate that `input` is a placeholder for actual data.

@@ -89,7 +95,7 @@

```python
    max_tokens=1000,
)

-for user in MultiUser.from_streaming_response(completion):
+for user in users:
    assert isinstance(user, User)
    print(user)
```

Reviewer comment (Contributor) on lines 91 to 96: The note about streaming being a prototype is important and should be highlighted or made more prominent, to ensure users are aware of its experimental status.


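Putting the pieces of this diff together, here is a minimal end-to-end sketch of the new streaming API. The model name, prompt, and printed output are illustrative assumptions, not part of the patch:

```python
import instructor
from openai import OpenAI
from typing import Iterable
from pydantic import BaseModel

# Patched client; FUNCTIONS is the default mode.
client = instructor.patch(OpenAI())

class User(BaseModel):
    name: str
    age: int

# Passing Iterable[User] as response_model triggers the MultiTask wrapping
# added in this PR; stream=True yields each User as soon as it is parsed.
users = client.chat.completions.create(
    model="gpt-4",
    temperature=0.1,
    stream=True,
    response_model=Iterable[User],
    messages=[
        {"role": "user", "content": "Jason is 10 and John is 30."},
    ],
    max_tokens=1000,
)

for user in users:
    print(user)  # e.g. name='Jason' age=10, then name='John' age=30
```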
docs/concepts/patching.md — 2 changes: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ client = instructor.patch(OpenAI())

```python
import instructor
from instructor import Mode

-client = instructor.patch(OpenAI(), mode=Mode.TOOL_CALL)
+client = instructor.patch(OpenAI(), mode=Mode.TOOLS)
```

## JSON Mode
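For context, a rough sketch of the three patch modes this PR touches (the `instructor.function_calls.Mode.JSON` spelling in lists.md above is the same enum as `Mode.JSON` here):

```python
import instructor
from instructor import Mode
from openai import OpenAI

# Mode.TOOLS replaces the old Mode.TOOL_CALL spelling.
client = instructor.patch(OpenAI(), mode=Mode.FUNCTIONS)  # function_call API
client = instructor.patch(OpenAI(), mode=Mode.TOOLS)      # tool_calls API
client = instructor.patch(OpenAI(), mode=Mode.JSON)       # JSON response mode
```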
examples/streaming_multitask/streaming_multitask.py — 14 changes: 6 additions & 8 deletions
@@ -16,14 +16,14 @@ class User(BaseModel):
    age: int


-def stream_extract(input: str, cls) -> Iterable[User]:
-    MultiUser = instructor.MultiTask(cls)
-    completion = client.chat.completions.create(
+Users = Iterable[User]
+
+def stream_extract(input: str) -> Users:
+    return client.chat.completions.create(
model="gpt-4-0613",
temperature=0.1,
stream=True,
functions=[MultiUser.openai_schema],
function_call={"name": MultiUser.openai_schema["name"]},
response_model=Users,
messages=[
{
"role": "system",
@@ -40,13 +40,11 @@ def stream_extract(input: str, cls) -> Iterable[User]:
        ],
        max_tokens=1000,
    )
-    return MultiUser.from_streaming_response(completion)


start = time.time()
for user in stream_extract(
-    input="Create 5 characters from the book Three Body Problem",
-    cls=User,
+    input="Create 5 characters from the book Three Body Problem"
):
    delay = round(time.time() - start, 1)
    print(f"{delay} s: User({user})")
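The diff elides the top of this example; presumably it sets up a patched client and the `User` model roughly as below. This is a hedged reconstruction of the assumed preamble, not part of the patch:

```python
# Assumed preamble for the example above (elided in the diff view).
import time
from typing import Iterable

import instructor
from openai import OpenAI
from pydantic import BaseModel

# Patch the client so response_model and stream are understood.
client = instructor.patch(OpenAI())

class User(BaseModel):
    name: str
    age: int
```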
instructor/dsl/multitask.py — 23 changes: 17 additions & 6 deletions
@@ -2,15 +2,15 @@

from pydantic import BaseModel, Field, create_model

-from instructor.function_calls import OpenAISchema
+from instructor.function_calls import OpenAISchema, Mode


class MultiTaskBase:
    task_type = None  # type: ignore

    @classmethod
-    def from_streaming_response(cls, completion):
-        json_chunks = cls.extract_json(completion)
+    def from_streaming_response(cls, completion, mode: Mode):
+        json_chunks = cls.extract_json(completion, mode)
        yield from cls.tasks_from_chunks(json_chunks)

    @classmethod

@@ -31,11 +31,22 @@ def tasks_from_chunks(cls, json_chunks):
            yield obj

    @staticmethod
-    def extract_json(completion):
+    def extract_json(completion, mode: Mode):
        for chunk in completion:
            try:
-                if json_chunk := chunk.choices[0].delta.function_call.arguments:
-                    yield json_chunk
+                if mode == Mode.FUNCTIONS:
+                    if json_chunk := chunk.choices[0].delta.function_call.arguments:
+                        yield json_chunk
+                elif mode == Mode.JSON:
+                    if json_chunk := chunk.choices[0].delta.content:
+                        yield json_chunk
+                elif mode == Mode.TOOLS:
+                    if json_chunk := chunk.choices[0].delta.tool_calls:
+                        yield json_chunk[0].function.arguments
+                else:
+                    raise NotImplementedError(
+                        f"Mode {mode} is not supported for MultiTask streaming"
+                    )
            except AttributeError:
                pass

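The body of `tasks_from_chunks` is mostly elided above; the idea is to buffer the streamed JSON-array text produced by `extract_json` and yield each element as soon as it parses. The sketch below is an illustration of that technique under the stated assumptions, not the library's exact implementation:

```python
import json
from typing import Iterator

from pydantic import BaseModel

def tasks_from_chunks_sketch(task_type: type[BaseModel], json_chunks) -> Iterator[BaseModel]:
    """Buffer streamed JSON-array text like '{"tasks": [{...}, {...}]}' and
    yield each completed element. Illustrative sketch only."""
    buffer = ""
    started = False
    for chunk in json_chunks:
        buffer += chunk
        if not started and "[" in buffer:
            # Drop the wrapper up to the opening '[' of the tasks array.
            buffer = buffer[buffer.find("[") + 1 :]
            started = True
        if not started:
            continue
        # Peel complete objects off the front of the buffer.
        progress = True
        while progress:
            progress = False
            for end, ch in enumerate(buffer):
                if ch != "}":
                    continue
                candidate = buffer[: end + 1].lstrip(", \n")
                try:
                    obj = task_type(**json.loads(candidate))
                except json.JSONDecodeError:
                    continue  # object not complete yet; try a later '}'
                buffer = buffer[end + 1 :]
                yield obj
                progress = True
                break
```

Something like `for user in tasks_from_chunks_sketch(User, extract_json(completion, mode))` mirrors what `from_streaming_response` wires together.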
instructor/function_calls.py — 8 changes: 8 additions & 0 deletions
@@ -190,6 +190,7 @@ def from_response(
        validation_context=None,
        strict: bool = None,
        mode: Mode = Mode.FUNCTIONS,
+        stream_multitask: bool = False,
    ):
        """Execute the function from the response of an openai chat completion

@@ -198,10 +199,17 @@
            throw_error (bool): Whether to throw an error if the function call is not detected
            validation_context (dict): The validation context to use for validating the response
            strict (bool): Whether to use strict json parsing
+            mode (Mode): The openai completion mode
+            stream_multitask (bool): Whether to stream a multitask response

        Returns:
            cls (OpenAISchema): An instance of the class
        """
+        if stream_multitask:
+            return cls.from_streaming_response(
+                completion, mode
+            )

        message = completion.choices[0].message

        if mode == Mode.FUNCTIONS:
instructor/patch.py — 35 changes: 21 additions & 14 deletions
@@ -1,7 +1,9 @@
import inspect
from functools import wraps
+from instructor.dsl.multitask import MultiTask, MultiTaskBase
from json import JSONDecodeError
-from typing import Callable, Optional, Type, Union
+from typing import get_origin, get_args, Callable, Optional, Type, Union
+from collections.abc import Iterable

from openai import AsyncOpenAI, OpenAI
from openai.types.chat import ChatCompletion
@@ -50,8 +52,14 @@ def handle_response_model(
):
    new_kwargs = kwargs.copy()
    if response_model is not None:
+        if get_origin(response_model) is Iterable:
+            iterable_element_class = get_args(response_model)[0]
+            response_model = MultiTask(iterable_element_class)
        if not issubclass(response_model, OpenAISchema):
            response_model = openai_schema(response_model)  # type: ignore

+        if new_kwargs.get("stream", False) and not issubclass(response_model, MultiTaskBase):
+            raise NotImplementedError("stream=True is not supported when using response_model parameter for non-iterables")

        if mode == Mode.FUNCTIONS:
            new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
@@ -72,7 +80,7 @@

        # check that the first message is a system message
        # if it is not, add a system message to the beginning
-        message = f"Make sure that your response to any message matchs the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"
+        message = f"Make sure that your response to any message matches the json_schema below, do not deviate at all: \n{response_model.model_json_schema()['properties']}"

        if new_kwargs["messages"][0]["role"] != "system":
            new_kwargs["messages"].insert(
@@ -89,41 +97,36 @@
    else:
        raise ValueError(f"Invalid patch mode: {mode}")

-    if new_kwargs.get("stream", False) and response_model is not None:
-        raise NotImplementedError(
-            "stream=True is not supported when using response_model parameter"
-        )
-
-        warnings.warn(
-            "stream=True is not supported when using response_model parameter"
-        )

    return response_model, new_kwargs


def process_response(
    response,
    *,
    response_model: Type[BaseModel],
+    stream: bool,
    validation_context: dict = None,
    strict=None,
    mode: Mode = Mode.FUNCTIONS,
):  # type: ignore
-    """Processes a OpenAI response with the response model, if available
+    """Processes a OpenAI response with the response model, if available.
    It can use `validation_context` and `strict` to validate the response
    via the pydantic model

    Args:
        response (ChatCompletion): The response from OpenAI's API
        response_model (BaseModel): The response model to use for parsing the response
+        stream (bool): Whether the response is a stream
        validation_context (dict, optional): The validation context to use for validating the response. Defaults to None.
        strict (bool, optional): Whether to use strict json parsing. Defaults to None.
    """
    if response_model is not None:
+        stream_multitask = stream and issubclass(response_model, MultiTaskBase)
        model = response_model.from_response(
-            response, validation_context=validation_context, strict=strict, mode=mode
+            response, validation_context=validation_context, strict=strict, mode=mode, stream_multitask=stream_multitask
        )
-        model._raw_response = response
+        if not stream:
+            model._raw_response = response
        return model
    return response

@@ -142,9 +145,11 @@ async def retry_async(
    while retries <= max_retries:
        try:
            response: ChatCompletion = await func(*args, **kwargs)
+            stream = kwargs.get("stream", False)
            return process_response(
                response,
                response_model=response_model,
+                stream=stream,
                validation_context=validation_context,
                strict=strict,
                mode=mode,
@@ -177,9 +182,11 @@ def retry_sync(
        # Excepts ValidationError, and JSONDecodeError
        try:
            response = func(*args, **kwargs)
+            stream = kwargs.get("stream", False)
            return process_response(
                response,
                response_model=response_model,
+                stream=stream,
                validation_context=validation_context,
                strict=strict,
                mode=mode,
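The `Iterable` detection added to `handle_response_model` leans on standard `typing` introspection. A quick self-contained illustration of those primitives (the `User` model is just a stand-in):

```python
from collections.abc import Iterable as ABCIterable
from typing import Iterable, get_args, get_origin

from pydantic import BaseModel

class User(BaseModel):
    name: str
    age: int

Users = Iterable[User]

# typing.Iterable[...] and collections.abc.Iterable[...] share the same
# origin, which is what the patched handle_response_model checks before
# wrapping the element class in MultiTask.
assert get_origin(Users) is ABCIterable
assert get_args(Users) == (User,)
```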