Refactoring of tool mechanism #274

Merged: 5 commits, Dec 23, 2024
autogen/agentchat/conversable_agent.py (58 additions, 41 deletions)
@@ -17,9 +17,6 @@

from openai import BadRequestError

from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from .._pydantic import BaseModel, model_dump
from ..cache.cache import AbstractCache
from ..code_utils import (
@@ -34,13 +31,15 @@
)
from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..exception_utils import InvalidCarryOverType, SenderRequired
from ..formatting_utils import colored
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from ..io.base import IOStream
from ..oai.client import ModelClient, OpenAIWrapper
from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
from ..tools import Tool
from .agent import Agent, LLMAgent
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .chat import ChatResult, _post_process_carryover_item, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary

__all__ = ("ConversableAgent",)
@@ -2695,7 +2694,7 @@ def register_for_llm(
name: Optional[str] = None,
description: Optional[str] = None,
api_style: Literal["function", "tool"] = "tool",
) -> Callable[[F], F]:
) -> Callable[[Union[F, Tool]], Tool]:
"""Decorator factory for registering a function to be used by an agent.

It's return value is used to decorate a function to be registered to the agent. The function uses type hints to
Expand Down Expand Up @@ -2735,7 +2734,7 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c

"""

def _decorator(func: F) -> F:
def _decorator(func_or_tool: Union[F, Tool]) -> Tool:
"""Decorator for registering a function to be used by an agent.

Args:
@@ -2749,42 +2748,62 @@ def _decorator(func: F) -> F:
RuntimeError: if the LLM config is not set up before registering a function.

"""
# name can be overwritten by the parameter, by default it is the same as function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__

# description is propagated from the previous decorator, but it is mandatory for the first one
if description:
func._description = description
else:
if not hasattr(func, "_description"):
raise ValueError("Function description is required, none found.")
nonlocal name, description

# get JSON schema for the function
f = get_function_schema(func, name=func._name, description=func._description)
tool = self._create_tool_if_needed(func_or_tool, name, description)

# register the function to the agent if there is LLM config, raise an exception otherwise
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")
self._register_for_llm(tool.func, tool.name, tool.description, api_style)

if api_style == "function":
f = f["function"]
self.update_function_signature(f, is_remove=False)
elif api_style == "tool":
self.update_tool_signature(f, is_remove=False)
else:
raise ValueError(f"Unsupported API style: {api_style}")

return func
return tool

return _decorator

def _create_tool_if_needed(
self, func_or_tool: Union[Tool, Callable[..., Any]], name: Optional[str], description: Optional[str]
) -> Tool:

if isinstance(func_or_tool, Tool):
tool: Tool = func_or_tool
tool._name = name or tool.name
tool._description = description or tool.description

return tool

if isinstance(func_or_tool, Callable):
func: Callable[..., Any] = func_or_tool

name = name or func.__name__

tool = Tool(name=name, description=description, func=func)

return tool

raise ValueError(
f"Parameter 'func_or_tool' must be a function or a Tool instance, it is '{type(func_or_tool)}' instead."
)

def _register_for_llm(
self, func: Callable[..., Any], name: str, description: str, api_style: Literal["tool", "function"]
) -> None:
# get JSON schema for the function
f = get_function_schema(func, name=name, description=description)

# register the function to the agent if there is LLM config, raise an exception otherwise
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")

if api_style == "function":
f = f["function"]
self.update_function_signature(f, is_remove=False)
elif api_style == "tool":
self.update_tool_signature(f, is_remove=False)
else:
raise ValueError(f"Unsupported API style: {api_style}")

def register_for_execution(
self,
name: Optional[str] = None,
) -> Callable[[F], F]:
) -> Callable[[Union[Tool, F]], Tool]:
"""Decorator factory for registering a function to be executed by an agent.

It's return value is used to decorate a function to be registered to the agent.
Expand All @@ -2806,7 +2825,7 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c

"""

def _decorator(func: F) -> F:
def _decorator(func_or_tool: Union[Tool, F]) -> Tool:
"""Decorator for registering a function to be used by an agent.

Args:
@@ -2819,15 +2838,13 @@ def _decorator(func: F) -> F:
ValueError: if the function description is not provided and not propagated by a previous decorator.

"""
# name can be overwritten by the parameter, by default it is the same as function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__
nonlocal name

tool = self._create_tool_if_needed(func_or_tool=func_or_tool, name=name, description=None)

self.register_function({func._name: self._wrap_function(func)})
self.register_function({tool.name: self._wrap_function(tool.func)})

return func
return tool

return _decorator

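Taken together, the changes above mean both decorators now accept either a plain function or a Tool instance and return a Tool. A minimal usage sketch, assuming agents configured roughly as below (the agent names, the llm_config contents, and the example functions are illustrative, not part of this PR):

```python
from typing import Annotated

from autogen import ConversableAgent
from autogen.tools import Tool

# Illustrative agent setup; any valid llm_config would do.
llm_agent = ConversableAgent(
    name="assistant",
    llm_config={"config_list": [{"model": "gpt-4o", "api_key": "<your-key>"}]},
)
executor = ConversableAgent(name="executor", llm_config=False)

# A plain function is wrapped into a Tool by the decorators, which now return that Tool.
@llm_agent.register_for_llm(description="Add two integers.")
@executor.register_for_execution()
def add(a: Annotated[int, "first addend"], b: Annotated[int, "second addend"]) -> int:
    return a + b

# A pre-built Tool can be passed through the same decorators unchanged.
def multiply(a: int, b: int) -> int:
    return a * b

multiply_tool = Tool(name="multiply", description="Multiply two integers.", func=multiply)
llm_agent.register_for_llm()(multiply_tool)
executor.register_for_execution()(multiply_tool)
```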
autogen/interop/pydantic_ai/pydantic_ai_tool.py (1 addition, 1 deletion)
@@ -38,7 +38,7 @@ def __init__(
parameters_json_schema (Dict[str, Any]): A schema describing the parameters
that the function accepts.
"""
super().__init__(name, description, func)
super().__init__(name=name, description=description, func=func)
self._func_schema = {
"type": "function",
"function": {
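Context for this one-line change: Tool.__init__ becomes keyword-only in this PR (see autogen/tools/tool.py below), so a positional super().__init__(name, description, func) call would now raise a TypeError. A small sketch, with a purely illustrative placeholder function:

```python
from autogen.tools import Tool

def noop() -> None:
    """Illustrative placeholder function."""

# Keyword arguments are required after this change.
Tool(name="noop", description="Does nothing.", func=noop)

# Positional construction is no longer accepted:
# Tool("noop", "Does nothing.", noop)  # TypeError: __init__() takes 1 positional argument
```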
autogen/tools/tool.py (21 additions, 7 deletions)
@@ -2,9 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Literal

from ..agentchat.conversable_agent import ConversableAgent
if TYPE_CHECKING:
from ..agentchat.conversable_agent import ConversableAgent

__all__ = ["Tool"]

@@ -22,7 +23,7 @@ class Tool:
func (Callable[..., Any]): The function to be executed when the tool is called.
"""

def __init__(self, name: str, description: str, func: Callable[..., Any]) -> None:
def __init__(self, *, name: str, description: str, func: Callable[..., Any]) -> None:
"""Create a new Tool object.

Args:
@@ -46,7 +47,7 @@ def description(self) -> str:
def func(self) -> Callable[..., Any]:
return self._func

def register_for_llm(self, agent: ConversableAgent) -> None:
def register_for_llm(self, agent: "ConversableAgent") -> None:
"""
Registers the tool for use with a ConversableAgent's language model (LLM).

@@ -56,9 +57,9 @@ def register_for_llm(self, agent: ConversableAgent) -> None:
Args:
agent (ConversableAgent): The agent to which the tool will be registered.
"""
agent.register_for_llm(name=self._name, description=self._description)(self._func)
agent.register_for_llm()(self)

def register_for_execution(self, agent: ConversableAgent) -> None:
def register_for_execution(self, agent: "ConversableAgent") -> None:
"""
Registers the tool for direct execution by a ConversableAgent.

@@ -68,4 +69,17 @@
Args:
agent (ConversableAgent): The agent to which the tool will be registered.
"""
agent.register_for_execution(name=self._name)(self._func)
agent.register_for_execution()(self)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Execute the tool by calling its underlying function with the provided arguments.

Args:
*args: Positional arguments to pass to the tool
**kwargs: Keyword arguments to pass to the tool

Returns:
The result of executing the tool's function.
"""
return self._func(*args, **kwargs)
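
A minimal sketch of the reworked Tool API: keyword-only construction, direct invocation via __call__, and registration that passes the Tool itself through the agent decorators. The get_weather function and the agent variables are illustrative, not part of this PR:

```python
from autogen.tools import Tool

def get_weather(city: str) -> str:
    return f"It is always sunny in {city}."

# Keyword-only constructor.
weather_tool = Tool(name="get_weather", description="Look up the weather for a city.", func=get_weather)

# Tool instances are now callable; __call__ forwards to the wrapped function.
print(weather_tool("Paris"))  # -> It is always sunny in Paris.

# Registration now hands the Tool itself to the agent decorators, so the
# tool's own name and description travel with it (agents as in the sketch above):
# weather_tool.register_for_llm(llm_agent)
# weather_tool.register_for_execution(executor)
```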