From a2ac3cab7516c365c78f7843bb292f9b40195b6a Mon Sep 17 00:00:00 2001
From: Josh Yan
Date: Tue, 16 Jul 2024 15:23:41 -0700
Subject: [PATCH 1/5] tool calls

---
 examples/tools/README.md |  3 ++
 examples/tools/main.py   | 85 ++++++++++++++++++++++++++++++++++++++++
 ollama/_client.py        | 22 +++++------
 ollama/_types.py         | 43 +++++++++++++++++++-
 4 files changed, 139 insertions(+), 14 deletions(-)
 create mode 100644 examples/tools/README.md
 create mode 100644 examples/tools/main.py

diff --git a/examples/tools/README.md b/examples/tools/README.md
new file mode 100644
index 0000000..85ca5dd
--- /dev/null
+++ b/examples/tools/README.md
@@ -0,0 +1,3 @@
+# tools
+
+This example demonstrates how to use tool calls with an asynchronous Ollama client and the chat endpoint.
diff --git a/examples/tools/main.py b/examples/tools/main.py
new file mode 100644
index 0000000..aa6aeb3
--- /dev/null
+++ b/examples/tools/main.py
@@ -0,0 +1,85 @@
+import json
+import ollama
+import asyncio
+
+# Simulates an API call to get flight times
+# In a real application, this would fetch data from a live database or API
+def get_flight_times(departure: str, arrival: str) -> str:
+    flights = {
+        "NYC-LAX": {"departure": "08:00 AM", "arrival": "11:30 AM", "duration": "5h 30m"},
+        "LAX-NYC": {"departure": "02:00 PM", "arrival": "10:30 PM", "duration": "5h 30m"},
+        "LHR-JFK": {"departure": "10:00 AM", "arrival": "01:00 PM", "duration": "8h 00m"},
+        "JFK-LHR": {"departure": "09:00 PM", "arrival": "09:00 AM", "duration": "7h 00m"},
+        "CDG-DXB": {"departure": "11:00 AM", "arrival": "08:00 PM", "duration": "6h 00m"},
+        "DXB-CDG": {"departure": "03:00 AM", "arrival": "07:30 AM", "duration": "7h 30m"},
+    }
+
+    key = f"{departure}-{arrival}".upper()
+    return json.dumps(flights.get(key, {"error": "Flight not found"}))
+
+async def run(model: str):
+    client = ollama.AsyncClient()
+    # Initialize conversation with a user query
+    messages = [{"role": "user", "content": "What is the flight time from New York (NYC) to Los Angeles (LAX)?"}]
+
+    # First API call: Send the query and function description to the model
+    response = await client.chat(
+        model=model,
+        messages=messages,
+        tools=[
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_flight_times",
+                    "description": "Get the flight times between two cities",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "departure": {
+                                "type": "string",
+                                "description": "The departure city (airport code)",
+                            },
+                            "arrival": {
+                                "type": "string",
+                                "description": "The arrival city (airport code)",
+                            },
+                        },
+                        "required": ["departure", "arrival"],
+                    },
+                },
+            },
+        ],
+    )
+
+    # Add the model's response to the conversation history
+    messages.append(response["message"])
+
+    # Check if the model decided to use the provided function
+    if not response["message"].get("tool_calls"):
+        print("The model didn't use the function. Its response was:")
Its response was:") + print(response["message"]["content"]) + return + + # Process function calls made by the model + if response["message"].get("tool_calls"): + available_functions = { + "get_flight_times": get_flight_times, + } + for tool in response["message"]["tool_calls"]: + function_to_call = available_functions[tool["function"]["name"]] + function_response = function_to_call( + tool["function"]["arguments"]["departure"], + tool["function"]["arguments"]["arrival"] + ) + # Add function response to the conversation + messages.append({ + "role": "tool", + "content": function_response, + }) + + # Second API call: Get final response from the model + final_response = await client.chat(model=model,messages=messages) + print(final_response["message"]["content"]) + +# Run the async function +asyncio.run(run("mistral")) diff --git a/ollama/_client.py b/ollama/_client.py index 1109aee..921dfe5 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -27,7 +27,7 @@ except metadata.PackageNotFoundError: __version__ = '0.0.0' -from ollama._types import Message, Options, RequestError, ResponseError +from ollama._types import Message, Options, RequestError, ResponseError, ToolCall, Tool class BaseClient: @@ -180,6 +180,7 @@ def chat( self, model: str = '', messages: Optional[Sequence[Message]] = None, + tools: Optional[Sequence[Tool]] = None, stream: Literal[False] = False, format: Literal['', 'json'] = '', options: Optional[Options] = None, @@ -191,6 +192,7 @@ def chat( self, model: str = '', messages: Optional[Sequence[Message]] = None, + tools: Optional[Sequence[Tool]] = None, stream: Literal[True] = True, format: Literal['', 'json'] = '', options: Optional[Options] = None, @@ -201,6 +203,7 @@ def chat( self, model: str = '', messages: Optional[Sequence[Message]] = None, + tools: Optional[Sequence[Tool]] = None, stream: bool = False, format: Literal['', 'json'] = '', options: Optional[Options] = None, @@ -222,12 +225,6 @@ def chat( messages = deepcopy(messages) for message in messages or []: - if not isinstance(message, dict): - raise TypeError('messages must be a list of Message or dict-like objects') - if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']: - raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"') - if 'content' not in message: - raise RequestError('messages must contain content') if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] @@ -237,6 +234,7 @@ def chat( json={ 'model': model, 'messages': messages, + 'tools': tools, 'stream': stream, 'format': format, 'options': options or {}, @@ -574,6 +572,7 @@ async def chat( self, model: str = '', messages: Optional[Sequence[Message]] = None, + tools: Optional[Sequence[Tool]] = None, stream: Literal[False] = False, format: Literal['', 'json'] = '', options: Optional[Options] = None, @@ -585,6 +584,7 @@ async def chat( self, model: str = '', messages: Optional[Sequence[Message]] = None, + tools: Optional[Sequence[Tool]] = None, stream: Literal[True] = True, format: Literal['', 'json'] = '', options: Optional[Options] = None, @@ -595,6 +595,7 @@ async def chat( self, model: str = '', messages: Optional[Sequence[Message]] = None, + tools: Optional[Sequence[Tool]] = None, stream: bool = False, format: Literal['', 'json'] = '', options: Optional[Options] = None, @@ -615,12 +616,6 @@ async def chat( messages = deepcopy(messages) for message in messages or []: - if not isinstance(message, dict): - raise 
-      if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
-        raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
-      if 'content' not in message:
-        raise RequestError('messages must contain content')
       if images := message.get('images'):
         message['images'] = [_encode_image(image) for image in images]
 
@@ -630,6 +625,7 @@ async def chat(
       json={
         'model': model,
         'messages': messages,
+        'tools': tools,
         'stream': stream,
         'format': format,
         'options': options or {},
diff --git a/ollama/_types.py b/ollama/_types.py
index e29888e..00fd530 100644
--- a/ollama/_types.py
+++ b/ollama/_types.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, TypedDict, Sequence, Literal
+from typing import Any, TypedDict, Sequence, Literal, Mapping
 
 import sys
 
@@ -52,6 +52,24 @@ class GenerateResponse(BaseGenerateResponse):
   context: Sequence[int]
   'Tokenized history up to the point of the response.'
 
+class ToolCallFunction(TypedDict):
+  """
+  Tool call function.
+  """
+
+  name: str
+  'Name of the function.'
+
+  arguments: NotRequired[Mapping[str, Any]]
+  'Arguments of the function.'
+
+class ToolCall(TypedDict):
+  """
+  Model tool calls.
+  """
+
+  function: ToolCallFunction
+  'Function to be called.'
 
 class Message(TypedDict):
   """
@@ -75,7 +93,30 @@ class Message(TypedDict):
 
   Valid image formats depend on the model. See the model card for more information.
   """
+
+  tool_calls: NotRequired[Sequence[ToolCall]]
+  """
+  Tool calls to be made by the model.
+  """
+
+class Property(TypedDict):
+  type: str
+  description: str
+  enum: NotRequired[Sequence[str]]  # `enum` is optional and can be a list of strings
+
+class Parameters(TypedDict):
+  type: str
+  required: Sequence[str]
+  properties: Mapping[str, Property]
+
+class ToolFunction(TypedDict):
+  name: str
+  description: str
+  parameters: Parameters
 
+class Tool(TypedDict):
+  type: str
+  function: ToolFunction
 
 class ChatResponse(BaseGenerateResponse):
   """

From 67c5ea36e847561966a5c1298b22e8d0f23e26fb Mon Sep 17 00:00:00 2001
From: Josh Yan
Date: Tue, 16 Jul 2024 15:30:00 -0700
Subject: [PATCH 2/5] remove unused ToolCall import

---
 ollama/_client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ollama/_client.py b/ollama/_client.py
index 921dfe5..760863f 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -27,7 +27,7 @@
 except metadata.PackageNotFoundError:
   __version__ = '0.0.0'
 
-from ollama._types import Message, Options, RequestError, ResponseError, ToolCall, Tool
+from ollama._types import Message, Options, RequestError, ResponseError, Tool
 
 
 class BaseClient:

From 9409f1afb2c65d954296c7e04516a6ef0663bb80 Mon Sep 17 00:00:00 2001
From: Josh Yan
Date: Tue, 16 Jul 2024 15:31:35 -0700
Subject: [PATCH 3/5] reformat

---
 examples/tools/main.py | 138 +++++++++++++++++++++--------------------
 ollama/_types.py       |  40 +++++++-----
 2 files changed, 94 insertions(+), 84 deletions(-)

diff --git a/examples/tools/main.py b/examples/tools/main.py
index aa6aeb3..133b238 100644
--- a/examples/tools/main.py
+++ b/examples/tools/main.py
@@ -1,85 +1,87 @@
 import json
-import ollama 
+import ollama
 import asyncio
 
+
 # Simulates an API call to get flight times
 # In a real application, this would fetch data from a live database or API
 def get_flight_times(departure: str, arrival: str) -> str:
-    flights = {
-        "NYC-LAX": {"departure": "08:00 AM", "arrival": "11:30 AM", "duration": "5h 30m"},
-        "LAX-NYC": {"departure": "02:00 PM", "arrival": "10:30 PM", "duration": "5h 30m"},
30m"}, - "LHR-JFK": {"departure": "10:00 AM", "arrival": "01:00 PM", "duration": "8h 00m"}, - "JFK-LHR": {"departure": "09:00 PM", "arrival": "09:00 AM", "duration": "7h 00m"}, - "CDG-DXB": {"departure": "11:00 AM", "arrival": "08:00 PM", "duration": "6h 00m"}, - "DXB-CDG": {"departure": "03:00 AM", "arrival": "07:30 AM", "duration": "7h 30m"}, - } + flights = { + 'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'}, + 'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'}, + 'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'}, + 'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'}, + 'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'}, + 'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'}, + } + + key = f'{departure}-{arrival}'.upper() + return json.dumps(flights.get(key, {'error': 'Flight not found'})) - key = f"{departure}-{arrival}".upper() - return json.dumps(flights.get(key, {"error": "Flight not found"})) async def run(model: str): - client = ollama.AsyncClient() - # Initialize conversation with a user query - messages = [{"role": "user", "content": "What is the flight time from New York (NYC) to Los Angeles (LAX)?"}] + client = ollama.AsyncClient() + # Initialize conversation with a user query + messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}] - # First API call: Send the query and function description to the model - response = await client.chat( - model=model, - messages=messages, - tools=[ - { - "type": "function", - "function": { - "name": "get_flight_times", - "description": "Get the flight times between two cities", - "parameters": { - "type": "object", - "properties": { - "departure": { - "type": "string", - "description": "The departure city (airport code)", - }, - "arrival": { - "type": "string", - "description": "The arrival city (airport code)", - }, - }, - "required": ["departure", "arrival"], - }, - }, + # First API call: Send the query and function description to the model + response = await client.chat( + model=model, + messages=messages, + tools=[ + { + 'type': 'function', + 'function': { + 'name': 'get_flight_times', + 'description': 'Get the flight times between two cities', + 'parameters': { + 'type': 'object', + 'properties': { + 'departure': { + 'type': 'string', + 'description': 'The departure city (airport code)', + }, + 'arrival': { + 'type': 'string', + 'description': 'The arrival city (airport code)', + }, }, - ], - ) - - # Add the model's response to the conversation history - messages.append(response["message"]) + 'required': ['departure', 'arrival'], + }, + }, + }, + ], + ) - # Check if the model decided to use the provided function - if not response["message"].get("tool_calls"): - print("The model didn't use the function. Its response was:") - print(response["message"]["content"]) - return + # Add the model's response to the conversation history + messages.append(response['message']) - # Process function calls made by the model - if response["message"].get("tool_calls"): - available_functions = { - "get_flight_times": get_flight_times, + # Check if the model decided to use the provided function + if not response['message'].get('tool_calls'): + print("The model didn't use the function. 
Its response was:") + print(response['message']['content']) + return + + # Process function calls made by the model + if response['message'].get('tool_calls'): + available_functions = { + 'get_flight_times': get_flight_times, + } + for tool in response['message']['tool_calls']: + function_to_call = available_functions[tool['function']['name']] + function_response = function_to_call(tool['function']['arguments']['departure'], tool['function']['arguments']['arrival']) + # Add function response to the conversation + messages.append( + { + 'role': 'tool', + 'content': function_response, } - for tool in response["message"]["tool_calls"]: - function_to_call = available_functions[tool["function"]["name"]] - function_response = function_to_call( - tool["function"]["arguments"]["departure"], - tool["function"]["arguments"]["arrival"] - ) - # Add function response to the conversation - messages.append({ - "role": "tool", - "content": function_response, - }) + ) + + # Second API call: Get final response from the model + final_response = await client.chat(model=model, messages=messages) + print(final_response['message']['content']) - # Second API call: Get final response from the model - final_response = await client.chat(model=model,messages=messages) - print(final_response["message"]["content"]) # Run the async function -asyncio.run(run("mistral")) +asyncio.run(run('mistral')) diff --git a/ollama/_types.py b/ollama/_types.py index 00fd530..03158ff 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -52,25 +52,28 @@ class GenerateResponse(BaseGenerateResponse): context: Sequence[int] 'Tokenized history up to the point of the response.' + class ToolCallFunction(TypedDict): """ Tool call function. """ - + name: str 'Name of the function.' - + args: NotRequired[Mapping[str, Any]] 'Arguments of the function.' - + + class ToolCall(TypedDict): """ Model tool calls. """ - + function: ToolCallFunction 'Function to be called.' + class Message(TypedDict): """ Chat message. @@ -93,30 +96,35 @@ class Message(TypedDict): Valid image formats depend on the model. See the model card for more information. """ - + tool_calls: NotRequired[Sequence[ToolCall]] """ Tools calls to be made by the model. 
""" + class Property(TypedDict): - type: str - description: str - enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings + type: str + description: str + enum: NotRequired[Sequence[str]] # `enum` is optional and can be a list of strings + class Parameters(TypedDict): - type: str - required: Sequence[str] - properties: Mapping[str, Property] + type: str + required: Sequence[str] + properties: Mapping[str, Property] + class ToolFunction(TypedDict): - name: str - description: str - parameters: Parameters + name: str + description: str + parameters: Parameters + class Tool(TypedDict): - type: str - function: ToolFunction + type: str + function: ToolFunction + class ChatResponse(BaseGenerateResponse): """ From 5d121a91b9122a3ef60400f05a7e6d0a9ea73fc5 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Tue, 16 Jul 2024 15:36:28 -0700 Subject: [PATCH 4/5] tests --- tests/test_client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_client.py b/tests/test_client.py index 711e3a7..727c239 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -26,6 +26,7 @@ def test_client_chat(httpserver: HTTPServer): json={ 'model': 'dummy', 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], 'stream': False, 'format': '', 'options': {}, @@ -73,6 +74,7 @@ def generate(): json={ 'model': 'dummy', 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], 'stream': True, 'format': '', 'options': {}, @@ -102,6 +104,7 @@ def test_client_chat_images(httpserver: HTTPServer): 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], }, ], + 'tools': [], 'stream': False, 'format': '', 'options': {}, @@ -522,6 +525,7 @@ async def test_async_client_chat(httpserver: HTTPServer): json={ 'model': 'dummy', 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], 'stream': False, 'format': '', 'options': {}, @@ -560,6 +564,7 @@ def generate(): json={ 'model': 'dummy', 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], 'stream': True, 'format': '', 'options': {}, @@ -590,6 +595,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer): 'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'], }, ], + 'tools': [], 'stream': False, 'format': '', 'options': {}, From 4b8c40e638d0390c7b8e1378643e9bc751752c86 Mon Sep 17 00:00:00 2001 From: Josh Yan Date: Tue, 16 Jul 2024 15:56:52 -0700 Subject: [PATCH 5/5] tests --- ollama/_client.py | 4 ++-- ollama/_types.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index 760863f..e640a34 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -234,7 +234,7 @@ def chat( json={ 'model': model, 'messages': messages, - 'tools': tools, + 'tools': tools or [], 'stream': stream, 'format': format, 'options': options or {}, @@ -625,7 +625,7 @@ async def chat( json={ 'model': model, 'messages': messages, - 'tools': tools, + 'tools': tools or [], 'stream': stream, 'format': format, 'options': options or {}, diff --git a/ollama/_types.py b/ollama/_types.py index 03158ff..7dcf9c5 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -99,7 +99,7 @@ class Message(TypedDict): tool_calls: NotRequired[Sequence[ToolCall]] """ - Tools calls to be made by the model. + Tools calls to be made by the model. """