Skip to content

Commit

Permalink
integrate tool calls (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshyan1 authored Jul 17, 2024
1 parent 1a15742 commit 359c63d
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 14 deletions.
3 changes: 3 additions & 0 deletions examples/tools/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# tools

This example demonstrates how to utilize tool calls with an asynchronous Ollama client and the chat endpoint.
87 changes: 87 additions & 0 deletions examples/tools/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import json
import ollama
import asyncio


# Simulates an API call to get flight times
# In a real application, this would fetch data from a live database or API
def get_flight_times(departure: str, arrival: str) -> str:
    """Look up a (simulated) flight between two airport codes.

    Stands in for a real database or API call.  Matching is
    case-insensitive on the airport codes.  Returns a JSON string with
    departure/arrival times and duration, or an error payload when the
    route is unknown.
    """
    schedule = {
        'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'},
        'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'},
        'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'},
        'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'},
        'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'},
        'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'},
    }

    route = f'{departure}-{arrival}'.upper()
    try:
        info = schedule[route]
    except KeyError:
        info = {'error': 'Flight not found'}
    return json.dumps(info)


async def run(model: str):
    """Demonstrate a two-step tool-calling conversation with an Ollama model.

    Sends a user query plus a tool schema, executes any tool call the
    model requests via a local function, feeds the result back, and
    prints the model's final answer.

    Args:
        model: Name of the Ollama model to chat with.
    """
    client = ollama.AsyncClient()
    # Initialize conversation with a user query
    messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}]

    # First API call: Send the query and function description to the model
    response = await client.chat(
        model=model,
        messages=messages,
        tools=[
            {
                'type': 'function',
                'function': {
                    'name': 'get_flight_times',
                    'description': 'Get the flight times between two cities',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'departure': {
                                'type': 'string',
                                'description': 'The departure city (airport code)',
                            },
                            'arrival': {
                                'type': 'string',
                                'description': 'The arrival city (airport code)',
                            },
                        },
                        'required': ['departure', 'arrival'],
                    },
                },
            },
        ],
    )

    # Add the model's response to the conversation history
    messages.append(response['message'])

    # If the model answered directly without requesting a tool, we are done.
    if not response['message'].get('tool_calls'):
        print("The model didn't use the function. Its response was:")
        print(response['message']['content'])
        return

    # Process the function calls made by the model.  (The original code
    # re-checked tool_calls here, but the early return above already
    # guarantees they exist, so the guard was redundant.)
    available_functions = {
        'get_flight_times': get_flight_times,
    }
    for tool in response['message']['tool_calls']:
        # An unknown function name raises KeyError, which is acceptable
        # for an example; a production handler would report it back.
        function_to_call = available_functions[tool['function']['name']]
        arguments = tool['function']['arguments']
        function_response = function_to_call(arguments['departure'], arguments['arrival'])
        # Add the tool's output to the conversation so the model can use it
        messages.append(
            {
                'role': 'tool',
                'content': function_response,
            }
        )

    # Second API call: Get final response from the model
    final_response = await client.chat(model=model, messages=messages)
    print(final_response['message']['content'])


# Script entry point: run the example only when executed directly, so
# importing this module does not trigger a network call to the model.
if __name__ == '__main__':
    asyncio.run(run('mistral'))
22 changes: 9 additions & 13 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
except metadata.PackageNotFoundError:
__version__ = '0.0.0'

from ollama._types import Message, Options, RequestError, ResponseError
from ollama._types import Message, Options, RequestError, ResponseError, Tool


class BaseClient:
Expand Down Expand Up @@ -180,6 +180,7 @@ def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -191,6 +192,7 @@ def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -201,6 +203,7 @@ def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -222,12 +225,6 @@ def chat(
messages = deepcopy(messages)

for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of Message or dict-like objects')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]

Expand All @@ -237,6 +234,7 @@ def chat(
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
Expand Down Expand Up @@ -574,6 +572,7 @@ async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[False] = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -585,6 +584,7 @@ async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: Literal[True] = True,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -595,6 +595,7 @@ async def chat(
self,
model: str = '',
messages: Optional[Sequence[Message]] = None,
tools: Optional[Sequence[Tool]] = None,
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
Expand All @@ -615,12 +616,6 @@ async def chat(
messages = deepcopy(messages)

for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
if 'content' not in message:
raise RequestError('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]

Expand All @@ -630,6 +625,7 @@ async def chat(
json={
'model': model,
'messages': messages,
'tools': tools or [],
'stream': stream,
'format': format,
'options': options or {},
Expand Down
51 changes: 50 additions & 1 deletion ollama/_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, TypedDict, Sequence, Literal
from typing import Any, TypedDict, Sequence, Literal, Mapping

import sys

Expand Down Expand Up @@ -53,6 +53,27 @@ class GenerateResponse(BaseGenerateResponse):
'Tokenized history up to the point of the response.'


class ToolCallFunction(TypedDict):
"""
Tool call function.
"""

name: str
'Name of the function.'

args: NotRequired[Mapping[str, Any]]
'Arguments of the function.'


class ToolCall(TypedDict):
  """
  A single tool invocation requested by the model.
  """

  function: ToolCallFunction
  'The function the model wants invoked.'


class Message(TypedDict):
"""
Chat message.
Expand All @@ -76,6 +97,34 @@ class Message(TypedDict):
Valid image formats depend on the model. See the model card for more information.
"""

tool_calls: NotRequired[Sequence[ToolCall]]
"""
Tool calls to be made by the model.
"""


class Property(TypedDict):
type: str
description: str
enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings


class Parameters(TypedDict):
  """
  JSON-Schema-style parameter specification for a tool function.
  """

  type: str
  required: Sequence[str]
  properties: Mapping[str, Property]


class ToolFunction(TypedDict):
  """
  Declaration of a callable function exposed to the model as a tool.
  """

  name: str
  description: str
  parameters: Parameters


class Tool(TypedDict):
  """
  A tool made available to the model in a chat request.
  """

  # Currently always 'function' in the Ollama API.
  type: str
  function: ToolFunction


class ChatResponse(BaseGenerateResponse):
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down Expand Up @@ -73,6 +74,7 @@ def generate():
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
Expand Down Expand Up @@ -102,6 +104,7 @@ def test_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down Expand Up @@ -522,6 +525,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down Expand Up @@ -560,6 +564,7 @@ def generate():
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'tools': [],
'stream': True,
'format': '',
'options': {},
Expand Down Expand Up @@ -590,6 +595,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'tools': [],
'stream': False,
'format': '',
'options': {},
Expand Down

0 comments on commit 359c63d

Please sign in to comment.