diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index be1c83b05d163..7a6637c271d56 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,15 +1,10 @@ """Interface for tools.""" from functools import partial -from inspect import signature from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union -from pydantic import BaseModel, validate_arguments, validator +from pydantic import BaseModel, validator -from langchain.tools.base import ( - BaseTool, - create_schema_from_function, - get_filtered_args, -) +from langchain.tools.base import BaseTool, StructuredTool class Tool(BaseTool): @@ -33,30 +28,21 @@ def args(self) -> dict: """The tool's input arguments.""" if self.args_schema is not None: return self.args_schema.schema()["properties"] - inferred_model = validate_arguments(self.func).model # type: ignore - filtered_args = get_filtered_args( - inferred_model, self.func, invalid_args={"args", "kwargs"} - ) - if filtered_args: - return filtered_args - # For backwards compatability, if the function signature is ambiguous, + # For backwards compatibility, if the function signature is ambiguous, # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _to_args_and_kwargs(self, tool_input: str | Dict) -> Tuple[Tuple, Dict]: + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: """Convert tool input to pydantic model.""" args, kwargs = super()._to_args_and_kwargs(tool_input) - if self.is_single_input: - # For backwards compatability. If no schema is inferred, - # the tool must assume it should be run with a single input - all_args = list(args) + list(kwargs.values()) - if len(all_args) != 1: - raise ValueError( - f"Too many arguments to single-input tool {self.name}." - f" Args: {all_args}" - ) - return tuple(all_args), {} - return args, kwargs + # For backwards compatibility. The tool must be run with a single input + all_args = list(args) + list(kwargs.values()) + if len(all_args) != 1: + raise ValueError( + f"Too many arguments to single-input tool {self.name}." + f" Args: {all_args}" + ) + return tuple(all_args), {} def _run(self, *args: Any, **kwargs: Any) -> Any: """Use the tool.""" @@ -129,22 +115,24 @@ def search_api(query: str) -> str: """ def _make_with_name(tool_name: str) -> Callable: - def _make_tool(func: Callable) -> Tool: - assert func.__doc__, "Function must have a docstring" - # Description example: - # search_api(query: str) - Searches the API for the query. - description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" - _args_schema = args_schema - if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{tool_name}Schema", func) - tool_ = Tool( + def _make_tool(func: Callable) -> BaseTool: + if infer_schema or args_schema is not None: + return StructuredTool.from_function( + func, + name=tool_name, + return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, + ) + # If someone doesn't want a schema applied, we must treat it as + # a simple string->string function + assert func.__doc__ is not None, "Function must have a docstring" + return Tool( name=tool_name, func=func, - args_schema=_args_schema, - description=description, + description=f"{tool_name} tool", return_direct=return_direct, ) - return tool_ return _make_tool diff --git a/langchain/tools/base.py b/langchain/tools/base.py index c400af43fabf8..9090c6839c9bd 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from inspect import signature -from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union from pydantic import ( BaseModel, @@ -75,16 +75,17 @@ def _create_subset_model( def get_filtered_args( inferred_model: Type[BaseModel], func: Callable, - invalid_args: Optional[Set[str]] = None, ) -> dict: """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - invalid_args = invalid_args or set() - return {k: schema[k] for k in valid_keys if k not in invalid_args} + return {k: schema[k] for k in valid_keys} -def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]: +def create_schema_from_function( + model_name: str, + func: Callable, +) -> Type[BaseModel]: """Create a pydantic schema from a function's signature.""" inferred_model = validate_arguments(func).model # type: ignore # Pydantic adds placeholder virtual fields we need to strip @@ -98,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): """Interface LangChain tools must implement.""" name: str + """The unique name of the tool that clearly communicates its purpose.""" description: str + """Used to tell the model how/when/why to use the tool. + + You can provide few-shot examples as a part of the description. + """ args_schema: Optional[Type[BaseModel]] = None """Pydantic model class to validate and parse the tool's input arguments.""" return_direct: bool = False + """Whether to return the tool's output directly. Setting this to True means + + that after the tool is called, the AgentExecutor will stop looping. + """ verbose: bool = False + """Whether to log the tool's progress.""" callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + """Callback manager for this tool.""" class Config: """Configuration for this pydantic object.""" @@ -157,7 +169,7 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any: """Use the tool asynchronously.""" def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: - # For backwards compatability, if run_input is a string, + # For backwards compatibility, if run_input is a string, # pass as a positional argument. if isinstance(tool_input, str): return (tool_input,), {} @@ -253,3 +265,62 @@ async def arun( def __call__(self, tool_input: Union[str, dict]) -> Any: """Make tool callable.""" return self.run(tool_input) + + +class StructuredTool(BaseTool): + """Tool that can operate on any number of inputs.""" + + description: str = "" + args_schema: Type[BaseModel] = Field(..., description="The tool schema.") + """The input arguments' schema.""" + func: Callable[..., str] + """The function to run when the tool is called.""" + coroutine: Optional[Callable[..., Awaitable[str]]] = None + """The asynchronous version of the function.""" + + @property + def args(self) -> dict: + """The tool's input arguments.""" + return self.args_schema.schema()["properties"] + + def _run(self, *args: Any, **kwargs: Any) -> Any: + """Use the tool.""" + return self.func(*args, **kwargs) + + async def _arun(self, *args: Any, **kwargs: Any) -> Any: + """Use the tool asynchronously.""" + if self.coroutine: + return await self.coroutine(*args, **kwargs) + raise NotImplementedError("Tool does not support async") + + @classmethod + def from_function( + cls, + func: Callable, + name: Optional[str] = None, + description: Optional[str] = None, + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, + **kwargs: Any, + ) -> StructuredTool: + name = name or func.__name__ + description = description or func.__doc__ + assert ( + description is not None + ), "Function must have a docstring if description not provided." + + # Description example: + # search_api(query: str) - Searches the API for the query. + description = f"{name}{signature(func)} - {description.strip()}" + _args_schema = args_schema + if _args_schema is None and infer_schema: + _args_schema = create_schema_from_function(f"{name}Schema", func) + return cls( + name=name, + func=func, + args_schema=_args_schema, + description=description, + return_direct=return_direct, + **kwargs, + ) diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 21fc9c388c15c..965d44db110e9 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -16,7 +16,7 @@ from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent from langchain.agents.tools import Tool, tool -from langchain.tools.base import BaseTool, SchemaAnnotationError +from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool def test_unnamed_decorator() -> None: @@ -27,7 +27,7 @@ def search_api(query: str) -> str: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search_api" assert not search_api.return_direct assert search_api("test") == "API result" @@ -145,7 +145,7 @@ def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str: """Return the arguments directly.""" return f"{arg1} {arg2} {arg3}" - assert isinstance(tool_func, Tool) + assert isinstance(tool_func, BaseTool) assert tool_func.args_schema == _MockSchema @@ -159,7 +159,7 @@ def structured_tool_input( """Return the arguments directly.""" return f"{arg1} {arg2} {arg3}" - assert isinstance(structured_tool_input, Tool) + assert isinstance(structured_tool_input, BaseTool) assert structured_tool_input.args_schema is not None assert ( structured_tool_input.args_schema.schema()["properties"] @@ -171,14 +171,14 @@ def structured_tool_input( def test_structured_args_decorator_no_infer_schema() -> None: """Test functionality with structured arguments parsed as a decorator.""" - @tool(infer_schema=False) + @tool def structured_tool_input( arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None ) -> str: """Return the arguments directly.""" return f"{arg1}, {arg2}, {opt_arg}" - assert isinstance(structured_tool_input, Tool) + assert isinstance(structured_tool_input, BaseTool) assert structured_tool_input.name == "structured_tool_input" args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}} expected_result = "1, 0.001, {'foo': 'bar'}" @@ -193,8 +193,9 @@ def unstructured_tool_input(tool_input: str) -> str: """Return the arguments directly.""" return f"{tool_input}" - assert isinstance(unstructured_tool_input, Tool) + assert isinstance(unstructured_tool_input, BaseTool) assert unstructured_tool_input.args_schema is None + assert unstructured_tool_input.run("foo") == "foo" def test_base_tool_inheritance_base_schema() -> None: @@ -225,18 +226,18 @@ def test_tool_lambda_args_schema() -> None: func=lambda tool_input: tool_input, ) assert tool.args_schema is None - expected_args = {"tool_input": {"title": "Tool Input"}} + expected_args = {"tool_input": {"type": "string"}} assert tool.args == expected_args -def test_tool_lambda_multi_args_schema() -> None: +def test_structured_tool_lambda_multi_args_schema() -> None: """Test args schema inference when the tool argument is a lambda function.""" - tool = Tool( + tool = StructuredTool.from_function( name="tool", description="A tool", func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore ) - assert tool.args_schema is None + assert tool.args_schema is not None expected_args = { "tool_input": {"title": "Tool Input"}, "other_arg": {"title": "Other Arg"}, @@ -268,7 +269,7 @@ def empty_tool_input() -> str: """Return a constant.""" return "the empty result" - assert isinstance(empty_tool_input, Tool) + assert isinstance(empty_tool_input, BaseTool) assert empty_tool_input.name == "empty_tool_input" assert empty_tool_input.args == {} assert empty_tool_input.run({}) == "the empty result" @@ -282,7 +283,7 @@ def search_api(query: str) -> str: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search" assert not search_api.return_direct @@ -295,7 +296,7 @@ def search_api(query: str) -> str: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search" assert search_api.return_direct @@ -308,7 +309,7 @@ def search_api(query: str) -> str: """Search the API for the query.""" return "API result" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) assert search_api.name == "search_api" assert search_api.return_direct @@ -325,7 +326,7 @@ def search_api( """Search the API for the query.""" return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}" - assert isinstance(search_api, Tool) + assert isinstance(search_api, BaseTool) result = search_api.run( tool_input={ "arg_0": "foo",