From 027638af8aba98a397284a221a4e6d76148e0b28 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Thu, 27 Apr 2023 15:36:11 -0700 Subject: [PATCH 1/2] Add validation on agent instantiation for multi-input tools (#3681) Tradeoffs here: - No lint-time checking for compatibility - Differs from JS package - The signature inference, etc. in the base tool isn't simple - The `args_schema` is optional Pros: - Forwards compatibility retained - Doesn't break backwards compatibility - User doesn't have to think about which class to subclass (single base tool or dynamic `Tool` interface regardless of input) - No need to change the load_tools, etc. interfaces Co-authored-by: Hasan Patel --- langchain/agents/tools.py | 58 +++++++++++++++++++-------- langchain/tools/base.py | 11 +++-- tests/unit_tests/agents/test_tools.py | 18 ++++++++- 3 files changed, 67 insertions(+), 20 deletions(-) diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 913110baddc12..4afd533c54dda 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -32,9 +32,11 @@ def validate_func_not_partial(cls, func: Callable) -> Callable: def args(self) -> dict: if self.args_schema is not None: return self.args_schema.schema()["properties"] - else: - inferred_model = validate_arguments(self.func).model # type: ignore - return get_filtered_args(inferred_model, self.func) + inferred_model = validate_arguments(self.func).model # type: ignore + filtered_args = get_filtered_args(inferred_model, self.func, {"args", "kwargs"}) + if filtered_args: + return filtered_args + return {"tool_input": {"type": "string"}} def _run(self, *args: Any, **kwargs: Any) -> str: """Use the tool.""" @@ -46,9 +48,41 @@ async def _arun(self, *args: Any, **kwargs: Any) -> str: return await self.coroutine(*args, **kwargs) raise NotImplementedError("Tool does not support async") + @classmethod + def from_function( + cls, + func: Callable, + name: Optional[str] = 
None, + description: Optional[str] = None, + return_direct: bool = False, + args_schema: Optional[Type[BaseModel]] = None, + infer_schema: bool = True, + **kwargs: Any, + ) -> "Tool": + name = name or func.__name__ + description = description or func.__doc__ + assert ( + description is not None + ), "Function must have a docstring if description not provided." + + # Description example: + # search_api(query: str) - Searches the API for the query. + description = f"{name}{signature(func)} - {description.strip()}" + _args_schema = args_schema + if _args_schema is None and infer_schema: + _args_schema = create_schema_from_function(f"{name}Schema", func) + return cls( + name=name, + func=func, + args_schema=_args_schema, + description=description, + return_direct=return_direct, + **kwargs, + ) + # TODO: this is for backwards compatibility, remove in future def __init__( - self, name: str, func: Callable[[str], str], description: str, **kwargs: Any + self, name: str, func: Callable, description: str, **kwargs: Any ) -> None: """Initialize tool.""" super(Tool, self).__init__( @@ -108,21 +142,13 @@ def search_api(query: str) -> str: def _make_with_name(tool_name: str) -> Callable: def _make_tool(func: Callable) -> Tool: - assert func.__doc__, "Function must have a docstring" - # Description example: - # search_api(query: str) - Searches the API for the query. 
- description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" - _args_schema = args_schema - if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{tool_name}Schema", func) - tool_ = Tool( + return Tool.from_function( + func, name=tool_name, - func=func, - args_schema=_args_schema, - description=description, return_direct=return_direct, + args_schema=args_schema, + infer_schema=infer_schema, ) - return tool_ return _make_tool diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 95aae5366b77f..78ccd03f84440 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from inspect import signature -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union from pydantic import ( BaseModel, @@ -81,11 +81,16 @@ def _create_subset_model( return create_model(name, **fields) # type: ignore -def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict: +def get_filtered_args( + inferred_model: Type[BaseModel], + func: Callable, + invalid_keys: Optional[Set[str]] = None, +) -> dict: """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys} + invalid_keys = invalid_keys or set() + return {k: schema[k] for k in valid_keys if k not in invalid_keys} def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]: diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 6cac896f39892..49680be319e95 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -1,7 +1,7 @@ """Test tool utils.""" from datetime import datetime from functools import partial -from typing import Optional, Type, Union +from typing import Any, 
Optional, Type, Union from unittest.mock import MagicMock import pydantic @@ -423,3 +423,19 @@ def the_tool(foo: str, bar: str) -> str: f" multi-input tool {the_tool.name}.", ): agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore + + +def test_tool_no_args_specified_assumes_str(): + """Older tools could assume *args and **kwargs were passed in.""" + + def ambiguous_function(*args: Any, **kwargs: Any) -> str: + """An ambiguously defined function.""" + return args[0] + + some_tool = Tool( + name="chain_run", + description="Run the chain", + func=ambiguous_function, + ) + expected_args = {"tool_input": {"type": "string"}} + assert some_tool.args == expected_args From 30b42514aa753219fd135b5b6818213749cdcb85 Mon Sep 17 00:00:00 2001 From: vowelparrot <130414180+vowelparrot@users.noreply.github.com> Date: Thu, 27 Apr 2023 16:58:23 -0700 Subject: [PATCH 2/2] Filter args when function is only *args and **kwargs --- langchain/agents/tools.py | 76 +++++++++++++-------------- langchain/tools/base.py | 31 ++++++----- tests/unit_tests/agents/test_tools.py | 6 ++- 3 files changed, 56 insertions(+), 57 deletions(-) diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 4afd533c54dda..be1c83b05d163 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,7 +1,7 @@ """Interface for tools.""" from functools import partial from inspect import signature -from typing import Any, Awaitable, Callable, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union from pydantic import BaseModel, validate_arguments, validator @@ -30,56 +30,44 @@ def validate_func_not_partial(cls, func: Callable) -> Callable: @property def args(self) -> dict: + """The tool's input arguments.""" if self.args_schema is not None: return self.args_schema.schema()["properties"] inferred_model = validate_arguments(self.func).model # type: ignore - filtered_args = get_filtered_args(inferred_model, self.func, {"args", 
"kwargs"}) + filtered_args = get_filtered_args( + inferred_model, self.func, invalid_args={"args", "kwargs"} + ) if filtered_args: return filtered_args + # For backwards compatibility, if the function signature is ambiguous, + # assume it takes a single string input. return {"tool_input": {"type": "string"}} - def _run(self, *args: Any, **kwargs: Any) -> str: + def _to_args_and_kwargs(self, tool_input: str | Dict) -> Tuple[Tuple, Dict]: + """Convert tool input to pydantic model.""" + args, kwargs = super()._to_args_and_kwargs(tool_input) + if self.is_single_input: + # For backwards compatibility. If no schema is inferred, + # the tool must assume it should be run with a single input + all_args = list(args) + list(kwargs.values()) + if len(all_args) != 1: + raise ValueError( + f"Too many arguments to single-input tool {self.name}." + f" Args: {all_args}" + ) + return tuple(all_args), {} + return args, kwargs + + def _run(self, *args: Any, **kwargs: Any) -> Any: """Use the tool.""" return self.func(*args, **kwargs) - async def _arun(self, *args: Any, **kwargs: Any) -> str: + async def _arun(self, *args: Any, **kwargs: Any) -> Any: """Use the tool asynchronously.""" if self.coroutine: return await self.coroutine(*args, **kwargs) raise NotImplementedError("Tool does not support async") - @classmethod - def from_function( - cls, - func: Callable, - name: Optional[str] = None, - description: Optional[str] = None, - return_direct: bool = False, - args_schema: Optional[Type[BaseModel]] = None, - infer_schema: bool = True, - **kwargs: Any, - ) -> "Tool": - name = name or func.__name__ - description = description or func.__doc__ - assert ( - description is not None - ), "Function must have a docstring if description not provided." - - # Description example: - # search_api(query: str) - Searches the API for the query. 
- description = f"{name}{signature(func)} - {description.strip()}" - _args_schema = args_schema - if _args_schema is None and infer_schema: - _args_schema = create_schema_from_function(f"{name}Schema", func) - return cls( - name=name, - func=func, - args_schema=_args_schema, - description=description, - return_direct=return_direct, - **kwargs, - ) - # TODO: this is for backwards compatibility, remove in future def __init__( self, name: str, func: Callable, description: str, **kwargs: Any @@ -142,13 +130,21 @@ def search_api(query: str) -> str: def _make_with_name(tool_name: str) -> Callable: def _make_tool(func: Callable) -> Tool: - return Tool.from_function( - func, + assert func.__doc__, "Function must have a docstring" + # Description example: + # search_api(query: str) - Searches the API for the query. + description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}" + _args_schema = args_schema + if _args_schema is None and infer_schema: + _args_schema = create_schema_from_function(f"{tool_name}Schema", func) + tool_ = Tool( name=tool_name, + func=func, + args_schema=_args_schema, + description=description, return_direct=return_direct, - args_schema=args_schema, - infer_schema=infer_schema, ) + return tool_ return _make_tool diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 78ccd03f84440..c400af43fabf8 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from inspect import signature -from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union from pydantic import ( BaseModel, @@ -19,15 +19,6 @@ from langchain.callbacks.base import BaseCallbackManager -def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]: - # For backwards compatability, if run_input is a string, - # pass as a positional argument. 
- if isinstance(run_input, str): - return (run_input,), {} - else: - return [], run_input - - class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" @@ -84,13 +75,13 @@ def _create_subset_model( def get_filtered_args( inferred_model: Type[BaseModel], func: Callable, - invalid_keys: Optional[Set[str]] = None, + invalid_args: Optional[Set[str]] = None, ) -> dict: """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - invalid_keys = invalid_keys or set() - return {k: schema[k] for k in valid_keys if k not in invalid_keys} + invalid_args = invalid_args or set() + return {k: schema[k] for k in valid_keys if k not in invalid_args} def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]: @@ -165,6 +156,14 @@ def _run(self, *args: Any, **kwargs: Any) -> Any: async def _arun(self, *args: Any, **kwargs: Any) -> Any: """Use the tool asynchronously.""" + def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: + # For backwards compatibility, if tool_input is a string, + # pass as a positional argument. 
+ if isinstance(tool_input, str): + return (tool_input,), {} + else: + return (), tool_input + def run( self, tool_input: Union[str, Dict], @@ -187,7 +186,7 @@ def run( **kwargs, ) try: - tool_args, tool_kwargs = _to_args_and_kwargs(tool_input) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) observation = self._run(*tool_args, **tool_kwargs) except (Exception, KeyboardInterrupt) as e: self.callback_manager.on_tool_error(e, verbose=verbose_) @@ -229,8 +228,8 @@ async def arun( ) try: # We then call the tool on the tool input to get an observation - args, kwargs = _to_args_and_kwargs(tool_input) - observation = await self._arun(*args, **kwargs) + tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) + observation = await self._arun(*tool_args, **tool_kwargs) except (Exception, KeyboardInterrupt) as e: if self.callback_manager.is_async: await self.callback_manager.on_tool_error(e, verbose=verbose_) diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 49680be319e95..21fc9c388c15c 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -425,7 +425,7 @@ def the_tool(foo: str, bar: str) -> str: agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore -def test_tool_no_args_specified_assumes_str(): +def test_tool_no_args_specified_assumes_str() -> None: """Older tools could assume *args and **kwargs were passed in.""" def ambiguous_function(*args: Any, **kwargs: Any) -> str: @@ -439,3 +439,7 @@ def ambiguous_function(*args: Any, **kwargs: Any) -> str: ) expected_args = {"tool_input": {"type": "string"}} assert some_tool.args == expected_args + assert some_tool.run("foobar") == "foobar" + assert some_tool.run({"tool_input": "foobar"}) == "foobar" + with pytest.raises(ValueError, match="Too many arguments to single-input tool"): + some_tool.run({"tool_input": "foobar", "other_input": "bar"})