-
-
Notifications
You must be signed in to change notification settings - Fork 732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
docs: add support for mistral tool calling #467
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from pydantic import BaseModel, Field | ||
from typing import Optional | ||
from mistralai.client import MistralClient | ||
from instructor.patch import patch | ||
from instructor.function_calls import Mode | ||
|
||
# Build a Mistral client and patch its chat entry point so it accepts
# `response_model=` and returns validated Pydantic objects instead of raw
# completions. MIST_TOOLS routes through Mistral's tool-calling protocol.
client = MistralClient()
new_chat = patch(create=client.chat, mode=Mode.MIST_TOOLS)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: we can make an example in the hub, and update patch.md. |
||
|
||
|
||
class UserDetail(BaseModel):
    """Structured user details extracted from free-form text.

    Every field is optional so the model can return ``null`` for any piece
    of information that is absent from the input (see the demo calls below,
    where missing attributes come back as ``null``).
    """

    age: Optional[int] = None  # age in years, when stated in the text
    name: Optional[str] = None  # user's name, when stated in the text
    role: Optional[str] = None  # occupation/role, when stated in the text
|
||
|
||
def get_user_detail(string: str) -> UserDetail:
    """Extract a ``UserDetail`` from free-form text via Mistral tool calling.

    Args:
        string: Free-form description of a user, e.g. ``"Jason is 25"``.

    Returns:
        A validated ``UserDetail``; fields the model cannot find in the
        input are left as ``None``.
    """
    # `new_chat` is the patched client.chat; `response_model` makes it
    # parse the tool call into a UserDetail rather than return raw text.
    return new_chat(
        model="mistral-large-latest",
        response_model=UserDetail,
        messages=[
            {
                "role": "user",
                "content": f"Get user details for {string}",
            },
        ],
    )  # type: ignore
|
||
|
||
# Demo: extract structured user details from increasingly sparse inputs.
# The bare triple-quoted strings show the expected JSON output of each call.

# Age and name are present in the text; role is absent -> null.
user = get_user_detail("Jason is 25 years old")
print(user.model_dump_json(indent=2))
"""
{
"age": 25,
"name": "Jason",
"role": null
}
"""

# All three fields can be extracted from the text.
user = get_user_detail("Jason is a 25 years old scientist")
print(user.model_dump_json(indent=2))
"""
{
"age": 25,
"name": "Jason",
"role": "scientist"
}
"""

# Nothing extractable: every field comes back null.
user = get_user_detail("User not found")
print(user.model_dump_json(indent=2))
"""
{
"age": null,
"name": null,
"role": null
}
"""
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -15,6 +15,7 @@ class Mode(enum.Enum): | |||||
FUNCTIONS: str = "function_call" | ||||||
PARALLEL_TOOLS: str = "parallel_tool_call" | ||||||
TOOLS: str = "tool_call" | ||||||
MIST_TOOLS: str = "mistral_tools" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
JSON: str = "json_mode" | ||||||
MD_JSON: str = "markdown_json_mode" | ||||||
JSON_SCHEMA: str = "json_schema_mode" | ||||||
|
@@ -114,7 +115,7 @@ def from_response( | |||||
context=validation_context, | ||||||
strict=strict, | ||||||
) | ||||||
elif mode == Mode.TOOLS: | ||||||
elif mode == Mode.TOOLS or mode == Mode.MIST_TOOLS: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer suggestion (presumably): use `mode in {Mode.TOOLS, Mode.MIST_TOOLS}` instead of the chained `or`. | ||||||
Suggested change | ||||||
|
||||||
assert ( | ||||||
len(message.tool_calls) == 1 | ||||||
), "Instructor does not support multiple tool calls, use List[Model] instead." | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -116,17 +116,20 @@ def handle_response_model( | |||||
if mode == Mode.FUNCTIONS: | ||||||
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore | ||||||
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore | ||||||
elif mode == Mode.TOOLS: | ||||||
elif mode == Mode.TOOLS or mode == Mode.MIST_TOOLS: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
new_kwargs["tools"] = [ | ||||||
{ | ||||||
"type": "function", | ||||||
"function": response_model.openai_schema, | ||||||
} | ||||||
] | ||||||
new_kwargs["tool_choice"] = { | ||||||
"type": "function", | ||||||
"function": {"name": response_model.openai_schema["name"]}, | ||||||
} | ||||||
if mode == Mode.MIST_TOOLS: | ||||||
new_kwargs["tool_choice"] = "any" | ||||||
else: | ||||||
new_kwargs["tool_choice"] = { | ||||||
"type": "function", | ||||||
"function": {"name": response_model.openai_schema["name"]}, | ||||||
} | ||||||
elif mode in {Mode.JSON, Mode.MD_JSON, Mode.JSON_SCHEMA}: | ||||||
# If its a JSON Mode we need to massage the prompt a bit | ||||||
# in order to get the response we want in a json format | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.