diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 747990a7cb..b8c4e53abe 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -17,9 +17,6 @@ from openai import BadRequestError -from autogen.agentchat.chat import _post_process_carryover_item -from autogen.exception_utils import InvalidCarryOverType, SenderRequired - from .._pydantic import BaseModel, model_dump from ..cache.cache import AbstractCache from ..code_utils import ( @@ -34,13 +31,15 @@ ) from ..coding.base import CodeExecutor from ..coding.factory import CodeExecutorFactory +from ..exception_utils import InvalidCarryOverType, SenderRequired from ..formatting_utils import colored from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str from ..io.base import IOStream from ..oai.client import ModelClient, OpenAIWrapper from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled +from ..tools import Tool from .agent import Agent, LLMAgent -from .chat import ChatResult, a_initiate_chats, initiate_chats +from .chat import ChatResult, _post_process_carryover_item, a_initiate_chats, initiate_chats from .utils import consolidate_chat_info, gather_usage_summary __all__ = ("ConversableAgent",) @@ -2695,7 +2694,7 @@ def register_for_llm( name: Optional[str] = None, description: Optional[str] = None, api_style: Literal["function", "tool"] = "tool", - ) -> Callable[[F], F]: + ) -> Callable[[Union[F, Tool]], Tool]: """Decorator factory for registering a function to be used by an agent. It's return value is used to decorate a function to be registered to the agent. The function uses type hints to @@ -2735,7 +2734,7 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c """ - def _decorator(func: F) -> F: + def _decorator(func_or_tool: Union[F, Tool]) -> Tool: """Decorator for registering a function to be used by an agent. Args: @@ -2749,42 +2748,62 @@ def _decorator(func: F) -> F: RuntimeError: if the LLM config is not set up before registering a function. """ - # name can be overwritten by the parameter, by default it is the same as function name - if name: - func._name = name - elif not hasattr(func, "_name"): - func._name = func.__name__ - - # description is propagated from the previous decorator, but it is mandatory for the first one - if description: - func._description = description - else: - if not hasattr(func, "_description"): - raise ValueError("Function description is required, none found.") + nonlocal name, description - # get JSON schema for the function - f = get_function_schema(func, name=func._name, description=func._description) + tool = self._create_tool_if_needed(func_or_tool, name, description) - # register the function to the agent if there is LLM config, raise an exception otherwise - if self.llm_config is None: - raise RuntimeError("LLM config must be setup before registering a function for LLM.") + self._register_for_llm(tool.func, tool.name, tool.description, api_style) - if api_style == "function": - f = f["function"] - self.update_function_signature(f, is_remove=False) - elif api_style == "tool": - self.update_tool_signature(f, is_remove=False) - else: - raise ValueError(f"Unsupported API style: {api_style}") - - return func + return tool return _decorator + def _create_tool_if_needed( + self, func_or_tool: Union[Tool, Callable[..., Any]], name: Optional[str], description: Optional[str] + ) -> Tool: + + if isinstance(func_or_tool, Tool): + tool: Tool = func_or_tool + tool._name = name or tool.name + tool._description = description or tool.description + + return tool + + if isinstance(func_or_tool, Callable): + func: Callable[..., Any] = func_or_tool + + name = name or func.__name__ + + tool = Tool(name=name, description=description, func=func) + + return tool + + raise ValueError( + "Parameter 'func_or_tool' must be a function or a Tool instance, it is '{type(func_or_tool)}' instead." + ) + + def _register_for_llm( + self, func: Callable[..., Any], name: str, description: str, api_style: Literal["tool", "function"] + ) -> None: + # get JSON schema for the function + f = get_function_schema(func, name=name, description=description) + + # register the function to the agent if there is LLM config, raise an exception otherwise + if self.llm_config is None: + raise RuntimeError("LLM config must be setup before registering a function for LLM.") + + if api_style == "function": + f = f["function"] + self.update_function_signature(f, is_remove=False) + elif api_style == "tool": + self.update_tool_signature(f, is_remove=False) + else: + raise ValueError(f"Unsupported API style: {api_style}") + def register_for_execution( self, name: Optional[str] = None, - ) -> Callable[[F], F]: + ) -> Callable[[Union[Tool, F]], Tool]: """Decorator factory for registering a function to be executed by an agent. It's return value is used to decorate a function to be registered to the agent. @@ -2806,7 +2825,7 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c """ - def _decorator(func: F) -> F: + def _decorator(func_or_tool: Union[Tool, F]) -> Tool: """Decorator for registering a function to be used by an agent. Args: @@ -2819,15 +2838,13 @@ def _decorator(func: F) -> F: ValueError: if the function description is not provided and not propagated by a previous decorator. """ - # name can be overwritten by the parameter, by default it is the same as function name - if name: - func._name = name - elif not hasattr(func, "_name"): - func._name = func.__name__ + nonlocal name + + tool = self._create_tool_if_needed(func_or_tool=func_or_tool, name=name, description=None) - self.register_function({func._name: self._wrap_function(func)}) + self.register_function({tool.name: self._wrap_function(tool.func)}) - return func + return tool return _decorator diff --git a/autogen/interop/pydantic_ai/pydantic_ai_tool.py b/autogen/interop/pydantic_ai/pydantic_ai_tool.py index 7ff50181ba..d83d5b109d 100644 --- a/autogen/interop/pydantic_ai/pydantic_ai_tool.py +++ b/autogen/interop/pydantic_ai/pydantic_ai_tool.py @@ -38,7 +38,7 @@ def __init__( parameters_json_schema (Dict[str, Any]): A schema describing the parameters that the function accepts. """ - super().__init__(name, description, func) + super().__init__(name=name, description=description, func=func) self._func_schema = { "type": "function", "function": { diff --git a/autogen/tools/tool.py b/autogen/tools/tool.py index 43914aa59c..d4fdd57719 100644 --- a/autogen/tools/tool.py +++ b/autogen/tools/tool.py @@ -2,9 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal -from ..agentchat.conversable_agent import ConversableAgent +if TYPE_CHECKING: + from ..agentchat.conversable_agent import ConversableAgent __all__ = ["Tool"] @@ -22,7 +23,7 @@ class Tool: func (Callable[..., Any]): The function to be executed when the tool is called. """ - def __init__(self, name: str, description: str, func: Callable[..., Any]) -> None: + def __init__(self, *, name: str, description: str, func: Callable[..., Any]) -> None: """Create a new Tool object. Args: @@ -46,7 +47,7 @@ def description(self) -> str: def func(self) -> Callable[..., Any]: return self._func - def register_for_llm(self, agent: ConversableAgent) -> None: + def register_for_llm(self, agent: "ConversableAgent") -> None: """ Registers the tool for use with a ConversableAgent's language model (LLM). @@ -56,9 +57,9 @@ def register_for_llm(self, agent: ConversableAgent) -> None: Args: agent (ConversableAgent): The agent to which the tool will be registered. """ - agent.register_for_llm(name=self._name, description=self._description)(self._func) + agent.register_for_llm()(self) - def register_for_execution(self, agent: ConversableAgent) -> None: + def register_for_execution(self, agent: "ConversableAgent") -> None: """ Registers the tool for direct execution by a ConversableAgent. @@ -68,4 +69,17 @@ def register_for_execution(self, agent: ConversableAgent) -> None: Args: agent (ConversableAgent): The agent to which the tool will be registered. """ - agent.register_for_execution(name=self._name)(self._func) + agent.register_for_execution()(self) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Execute the tool by calling its underlying function with the provided arguments. + + Args: + *args: Positional arguments to pass to the tool + **kwargs: Keyword arguments to pass to the tool + + Returns: + The result of executing the tool's function. + """ + return self._func(*args, **kwargs) diff --git a/notebook/agentchat_function_call_currency_calculator.ipynb b/notebook/agentchat_function_call_currency_calculator.ipynb index 659cea748e..0b13e8230a 100644 --- a/notebook/agentchat_function_call_currency_calculator.ipynb +++ b/notebook/agentchat_function_call_currency_calculator.ipynb @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 1, "id": "2b803c17", "metadata": {}, "outputs": [], @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 2, "id": "dca301a4", "metadata": {}, "outputs": [], @@ -120,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 3, "id": "9fb85afb", "metadata": {}, "outputs": [], @@ -180,35 +180,10 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "3e52bbfe", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'type': 'function',\n", - " 'function': {'description': 'Currency exchange calculator.',\n", - " 'name': 'currency_calculator',\n", - " 'parameters': {'type': 'object',\n", - " 'properties': {'base_amount': {'type': 'number',\n", - " 'description': 'Amount of currency in base_currency'},\n", - " 'base_currency': {'enum': ['USD', 'EUR'],\n", - " 'type': 'string',\n", - " 'default': 'USD',\n", - " 'description': 'Base currency'},\n", - " 'quote_currency': {'enum': ['USD', 'EUR'],\n", - " 'type': 'string',\n", - " 'default': 'EUR',\n", - " 'description': 'Quote currency'}},\n", - " 'required': ['base_amount']}}}]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "chatbot.llm_config[\"tools\"]" ] @@ -229,12 +204,12 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 6, "id": "bd943369", "metadata": {}, "outputs": [], "source": [ - "assert user_proxy.function_map[\"currency_calculator\"]._origin == currency_calculator" + "assert user_proxy.function_map[\"currency_calculator\"]._origin == currency_calculator.func" ] }, { @@ -247,66 +222,10 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "d5518947", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "How much is 123.45 USD in EUR?\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "\u001b[32m***** Suggested tool call (call_9ogJS4d40BT1rXfMn7YJb151): currency_calculator *****\u001b[0m\n", - "Arguments: \n", - "{\n", - " \"base_amount\": 123.45,\n", - " \"base_currency\": \"USD\",\n", - " \"quote_currency\": \"EUR\"\n", - "}\n", - "\u001b[32m************************************************************************************\u001b[0m\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[35m\n", - ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\u001b[32m***** Response from calling tool (call_9ogJS4d40BT1rXfMn7YJb151) *****\u001b[0m\n", - "112.22727272727272 EUR\n", - "\u001b[32m**********************************************************************\u001b[0m\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "123.45 USD is equivalent to 112.23 EUR.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "TERMINATE\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - } - ], + "outputs": [], "source": [ "with Cache.disk() as cache:\n", " # start the conversation\n", @@ -317,18 +236,10 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "4b5a0edc", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chat summary: 123.45 USD is equivalent to 112.23 EUR.\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Chat summary:\", res.summary)" ] @@ -351,7 +262,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 9, "id": "7b3d8b58", "metadata": {}, "outputs": [], @@ -400,101 +311,20 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "971ed0d5", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'type': 'function',\n", - " 'function': {'description': 'Currency exchange calculator.',\n", - " 'name': 'currency_calculator',\n", - " 'parameters': {'type': 'object',\n", - " 'properties': {'base': {'properties': {'currency': {'description': 'Currency symbol',\n", - " 'enum': ['USD', 'EUR'],\n", - " 'title': 'Currency',\n", - " 'type': 'string'},\n", - " 'amount': {'default': 0,\n", - " 'description': 'Amount of currency',\n", - " 'minimum': 0.0,\n", - " 'title': 'Amount',\n", - " 'type': 'number'}},\n", - " 'required': ['currency'],\n", - " 'title': 'Currency',\n", - " 'type': 'object',\n", - " 'description': 'Base currency: amount and currency symbol'},\n", - " 'quote_currency': {'enum': ['USD', 'EUR'],\n", - " 'type': 'string',\n", - " 'default': 'USD',\n", - " 'description': 'Quote currency symbol'}},\n", - " 'required': ['base']}}}]" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "chatbot.llm_config[\"tools\"]" ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "ab081090", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "How much is 112.23 Euros in US Dollars?\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "\u001b[32m***** Suggested tool call (call_BQkSmdFHsrKvmtDWCk0mY5sF): currency_calculator *****\u001b[0m\n", - "Arguments: \n", - "{\n", - " \"base\": {\n", - " \"currency\": \"EUR\",\n", - " \"amount\": 112.23\n", - " },\n", - " \"quote_currency\": \"USD\"\n", - "}\n", - "\u001b[32m************************************************************************************\u001b[0m\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[35m\n", - ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\u001b[32m***** Response from calling tool (call_BQkSmdFHsrKvmtDWCk0mY5sF) *****\u001b[0m\n", - "{\"currency\":\"USD\",\"amount\":123.45300000000002}\n", - "\u001b[32m**********************************************************************\u001b[0m\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "112.23 Euros is equivalent to 123.45 US Dollars.\n", - "TERMINATE\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - } - ], + "outputs": [], "source": [ "with Cache.disk() as cache:\n", " # start the conversation\n", @@ -505,86 +335,20 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "id": "4799f60c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chat summary: 112.23 Euros is equivalent to 123.45 US Dollars.\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Chat summary:\", res.summary)" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "id": "0064d9cd", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "How much is 123.45 US Dollars in Euros?\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "\u001b[32m***** Suggested tool call (call_Xxol42xTswZHGX60OjvIQRG1): currency_calculator *****\u001b[0m\n", - "Arguments: \n", - "{\n", - " \"base\": {\n", - " \"currency\": \"USD\",\n", - " \"amount\": 123.45\n", - " },\n", - " \"quote_currency\": \"EUR\"\n", - "}\n", - "\u001b[32m************************************************************************************\u001b[0m\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[35m\n", - ">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\u001b[32m***** Response from calling tool (call_Xxol42xTswZHGX60OjvIQRG1) *****\u001b[0m\n", - "{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n", - "\u001b[32m**********************************************************************\u001b[0m\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "123.45 US Dollars is equivalent to 112.23 Euros.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33muser_proxy\u001b[0m (to chatbot):\n", - "\n", - "\n", - "\n", - "--------------------------------------------------------------------------------\n", - "\u001b[33mchatbot\u001b[0m (to user_proxy):\n", - "\n", - "TERMINATE\n", - "\n", - "--------------------------------------------------------------------------------\n" - ] - } - ], + "outputs": [], "source": [ "with Cache.disk() as cache:\n", " # start the conversation\n", @@ -597,18 +361,10 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "id": "80b2b42c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Chat history: [{'content': 'How much is 123.45 US Dollars in Euros?', 'role': 'assistant'}, {'tool_calls': [{'id': 'call_Xxol42xTswZHGX60OjvIQRG1', 'function': {'arguments': '{\\n \"base\": {\\n \"currency\": \"USD\",\\n \"amount\": 123.45\\n },\\n \"quote_currency\": \"EUR\"\\n}', 'name': 'currency_calculator'}, 'type': 'function'}], 'content': None, 'role': 'assistant'}, {'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}', 'tool_responses': [{'tool_call_id': 'call_Xxol42xTswZHGX60OjvIQRG1', 'role': 'tool', 'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}'}], 'role': 'tool'}, {'content': '123.45 US Dollars is equivalent to 112.23 Euros.', 'role': 'user'}, {'content': '', 'role': 'assistant'}, {'content': 'TERMINATE', 'role': 'user'}]\n" - ] - } - ], + "outputs": [], "source": [ "print(\"Chat history:\", res.chat_history)" ] @@ -636,7 +392,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.16" } }, "nbformat": 4, diff --git a/scripts/pre-commit-license-check.py b/scripts/pre-commit-license-check.py old mode 100644 new mode 100755 diff --git a/scripts/test_skip_openai.sh b/scripts/test_skip_openai.sh new file mode 100755 index 0000000000..ff741d8322 --- /dev/null +++ b/scripts/test_skip_openai.sh @@ -0,0 +1,7 @@ +#! /bin/env bash + +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +pytest test --ignore=test/agentchat/contrib --skip-openai --durations=10 --durations-min=1.0 diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index d43f2dba3f..45775a55d7 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -24,6 +24,7 @@ from autogen.agentchat import ConversableAgent, UserProxyAgent from autogen.agentchat.conversable_agent import register_function from autogen.exception_utils import InvalidCarryOverType, SenderRequired +from autogen.tools.tool import Tool sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from conftest import MOCK_OPEN_AI_API_KEY, reason, skip_openai # noqa: E402 @@ -812,14 +813,12 @@ def test_register_for_llm_without_description(): mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY) agent = ConversableAgent(name="agent", llm_config={"config_list": gpt4_config_list}) - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match=" Input should be a valid string"): @agent.register_for_llm() def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str: pass - assert e.value.args[0] == "Function description is required, none found." - def test_register_for_llm_without_LLM(): agent = ConversableAgent(name="agent", llm_config=None) @@ -863,11 +862,11 @@ def test_register_for_execution(): def exec_python(cell: Annotated[str, "Valid Python cell to execute."]): pass - expected_function_map_1 = {"exec_python": exec_python} + expected_function_map_1 = {"exec_python": exec_python.func} assert get_origin(agent.function_map) == expected_function_map_1 assert get_origin(user_proxy_1.function_map) == expected_function_map_1 - expected_function_map_2 = {"python": exec_python} + expected_function_map_2 = {"python": exec_python.func} assert get_origin(user_proxy_2.function_map) == expected_function_map_2 @agent.register_for_execution() @@ -877,8 +876,8 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]): pass expected_function_map = { - "exec_python": exec_python, - "sh": exec_sh, + "exec_python": exec_python.func, + "sh": exec_sh.func, } assert get_origin(agent.function_map) == expected_function_map assert get_origin(user_proxy_1.function_map) == expected_function_map diff --git a/test/tools/test_tool.py b/test/tools/test_tool.py index b21edebd45..046ded611b 100644 --- a/test/tools/test_tool.py +++ b/test/tools/test_tool.py @@ -57,3 +57,6 @@ def test_register_for_execution(self) -> None: self.tool.register_for_execution(user_proxy) assert user_proxy.can_execute_function("test_tool") assert user_proxy.function_map["test_tool"]("Hello") == "Hello!" + + def test__call__(self) -> None: + assert self.tool("Hello") == "Hello!"