Skip to content

Commit

Permalink
Merge pull request ag2ai#274 from ag2ai/tool-refactoring
Browse files Browse the repository at this point in the history
Refactoring of tool mechanism
  • Loading branch information
qingyun-wu authored Dec 23, 2024
2 parents eb6458e + edd8a06 commit 307e196
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 323 deletions.
99 changes: 58 additions & 41 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

from openai import BadRequestError

from autogen.agentchat.chat import _post_process_carryover_item
from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from .._pydantic import BaseModel, model_dump
from ..cache.cache import AbstractCache
from ..code_utils import (
Expand All @@ -34,13 +31,15 @@
)
from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory
from ..exception_utils import InvalidCarryOverType, SenderRequired
from ..formatting_utils import colored
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from ..io.base import IOStream
from ..oai.client import ModelClient, OpenAIWrapper
from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
from ..tools import Tool
from .agent import Agent, LLMAgent
from .chat import ChatResult, a_initiate_chats, initiate_chats
from .chat import ChatResult, _post_process_carryover_item, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary

__all__ = ("ConversableAgent",)
Expand Down Expand Up @@ -2695,7 +2694,7 @@ def register_for_llm(
name: Optional[str] = None,
description: Optional[str] = None,
api_style: Literal["function", "tool"] = "tool",
) -> Callable[[F], F]:
) -> Callable[[Union[F, Tool]], Tool]:
"""Decorator factory for registering a function to be used by an agent.
It's return value is used to decorate a function to be registered to the agent. The function uses type hints to
Expand Down Expand Up @@ -2735,7 +2734,7 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c
"""

def _decorator(func: F) -> F:
def _decorator(func_or_tool: Union[F, Tool]) -> Tool:
"""Decorator for registering a function to be used by an agent.
Args:
Expand All @@ -2749,42 +2748,62 @@ def _decorator(func: F) -> F:
RuntimeError: if the LLM config is not set up before registering a function.
"""
# name can be overwritten by the parameter, by default it is the same as function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__

# description is propagated from the previous decorator, but it is mandatory for the first one
if description:
func._description = description
else:
if not hasattr(func, "_description"):
raise ValueError("Function description is required, none found.")
nonlocal name, description

# get JSON schema for the function
f = get_function_schema(func, name=func._name, description=func._description)
tool = self._create_tool_if_needed(func_or_tool, name, description)

# register the function to the agent if there is LLM config, raise an exception otherwise
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")
self._register_for_llm(tool.func, tool.name, tool.description, api_style)

if api_style == "function":
f = f["function"]
self.update_function_signature(f, is_remove=False)
elif api_style == "tool":
self.update_tool_signature(f, is_remove=False)
else:
raise ValueError(f"Unsupported API style: {api_style}")

return func
return tool

return _decorator

def _create_tool_if_needed(
self, func_or_tool: Union[Tool, Callable[..., Any]], name: Optional[str], description: Optional[str]
) -> Tool:

if isinstance(func_or_tool, Tool):
tool: Tool = func_or_tool
tool._name = name or tool.name
tool._description = description or tool.description

return tool

if isinstance(func_or_tool, Callable):
func: Callable[..., Any] = func_or_tool

name = name or func.__name__

tool = Tool(name=name, description=description, func=func)

return tool

raise ValueError(
"Parameter 'func_or_tool' must be a function or a Tool instance, it is '{type(func_or_tool)}' instead."
)

def _register_for_llm(
self, func: Callable[..., Any], name: str, description: str, api_style: Literal["tool", "function"]
) -> None:
# get JSON schema for the function
f = get_function_schema(func, name=name, description=description)

# register the function to the agent if there is LLM config, raise an exception otherwise
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")

if api_style == "function":
f = f["function"]
self.update_function_signature(f, is_remove=False)
elif api_style == "tool":
self.update_tool_signature(f, is_remove=False)
else:
raise ValueError(f"Unsupported API style: {api_style}")

def register_for_execution(
self,
name: Optional[str] = None,
) -> Callable[[F], F]:
) -> Callable[[Union[Tool, F]], Tool]:
"""Decorator factory for registering a function to be executed by an agent.
It's return value is used to decorate a function to be registered to the agent.
Expand All @@ -2806,7 +2825,7 @@ def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c
"""

def _decorator(func: F) -> F:
def _decorator(func_or_tool: Union[Tool, F]) -> Tool:
"""Decorator for registering a function to be used by an agent.
Args:
Expand All @@ -2819,15 +2838,13 @@ def _decorator(func: F) -> F:
ValueError: if the function description is not provided and not propagated by a previous decorator.
"""
# name can be overwritten by the parameter, by default it is the same as function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__
nonlocal name

tool = self._create_tool_if_needed(func_or_tool=func_or_tool, name=name, description=None)

self.register_function({func._name: self._wrap_function(func)})
self.register_function({tool.name: self._wrap_function(tool.func)})

return func
return tool

return _decorator

Expand Down
2 changes: 1 addition & 1 deletion autogen/interop/pydantic_ai/pydantic_ai_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
parameters_json_schema (Dict[str, Any]): A schema describing the parameters
that the function accepts.
"""
super().__init__(name, description, func)
super().__init__(name=name, description=description, func=func)
self._func_schema = {
"type": "function",
"function": {
Expand Down
28 changes: 21 additions & 7 deletions autogen/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Literal

from ..agentchat.conversable_agent import ConversableAgent
if TYPE_CHECKING:
from ..agentchat.conversable_agent import ConversableAgent

__all__ = ["Tool"]

Expand All @@ -22,7 +23,7 @@ class Tool:
func (Callable[..., Any]): The function to be executed when the tool is called.
"""

def __init__(self, name: str, description: str, func: Callable[..., Any]) -> None:
def __init__(self, *, name: str, description: str, func: Callable[..., Any]) -> None:
"""Create a new Tool object.
Args:
Expand All @@ -46,7 +47,7 @@ def description(self) -> str:
def func(self) -> Callable[..., Any]:
return self._func

def register_for_llm(self, agent: ConversableAgent) -> None:
def register_for_llm(self, agent: "ConversableAgent") -> None:
"""
Registers the tool for use with a ConversableAgent's language model (LLM).
Expand All @@ -56,9 +57,9 @@ def register_for_llm(self, agent: ConversableAgent) -> None:
Args:
agent (ConversableAgent): The agent to which the tool will be registered.
"""
agent.register_for_llm(name=self._name, description=self._description)(self._func)
agent.register_for_llm()(self)

def register_for_execution(self, agent: ConversableAgent) -> None:
def register_for_execution(self, agent: "ConversableAgent") -> None:
"""
Registers the tool for direct execution by a ConversableAgent.
Expand All @@ -68,4 +69,17 @@ def register_for_execution(self, agent: ConversableAgent) -> None:
Args:
agent (ConversableAgent): The agent to which the tool will be registered.
"""
agent.register_for_execution(name=self._name)(self._func)
agent.register_for_execution()(self)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Execute the tool by calling its underlying function with the provided arguments.
Args:
*args: Positional arguments to pass to the tool
**kwargs: Keyword arguments to pass to the tool
Returns:
The result of executing the tool's function.
"""
return self._func(*args, **kwargs)
Loading

0 comments on commit 307e196

Please sign in to comment.