Skip to content

Commit

Permalink
Tool Single-Input, Structured Tool Any-Input
Browse files Browse the repository at this point in the history
  • Loading branch information
vowelparrot committed Apr 28, 2023
1 parent c91c71d commit ec64e7d
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 60 deletions.
64 changes: 26 additions & 38 deletions langchain/agents/tools.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union

from pydantic import BaseModel, validate_arguments, validator
from pydantic import BaseModel, validator

from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
)
from langchain.tools.base import BaseTool, StructuredTool


class Tool(BaseTool):
Expand All @@ -33,30 +28,21 @@ def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
inferred_model = validate_arguments(self.func).model # type: ignore
filtered_args = get_filtered_args(
inferred_model, self.func, invalid_args={"args", "kwargs"}
)
if filtered_args:
return filtered_args
# For backwards compatability, if the function signature is ambiguous,
# For backwards compatibility, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}

def _to_args_and_kwargs(self, tool_input: str | Dict) -> Tuple[Tuple, Dict]:
def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
if self.is_single_input:
# For backwards compatability. If no schema is inferred,
# the tool must assume it should be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
raise ValueError(
f"Too many arguments to single-input tool {self.name}."
f" Args: {all_args}"
)
return tuple(all_args), {}
return args, kwargs
# For backwards compatibility. The tool must be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
raise ValueError(
f"Too many arguments to single-input tool {self.name}."
f" Args: {all_args}"
)
return tuple(all_args), {}

def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
Expand Down Expand Up @@ -129,22 +115,24 @@ def search_api(query: str) -> str:
"""

def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> Tool:
assert func.__doc__, "Function must have a docstring"
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_ = Tool(
def _make_tool(func: Callable) -> BaseTool:
if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
name=tool_name,
return_direct=return_direct,
args_schema=args_schema,
infer_schema=infer_schema,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring"
return Tool(
name=tool_name,
func=func,
args_schema=_args_schema,
description=description,
description=f"{tool_name} tool",
return_direct=return_direct,
)
return tool_

return _make_tool

Expand Down
83 changes: 77 additions & 6 deletions langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -75,16 +75,17 @@ def _create_subset_model(
def get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
invalid_args: Optional[Set[str]] = None,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
invalid_args = invalid_args or set()
return {k: schema[k] for k in valid_keys if k not in invalid_args}
return {k: schema[k] for k in valid_keys}


def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
Expand All @@ -98,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Interface LangChain tools must implement."""

name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool.
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means
that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False
"""Whether to log the tool's progress."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
"""Callback manager for this tool."""

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -157,7 +169,7 @@ async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""

def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatability, if run_input is a string,
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
Expand Down Expand Up @@ -253,3 +265,62 @@ async def arun(
def __call__(self, tool_input: Union[str, dict]) -> Any:
"""Make tool callable."""
return self.run(tool_input)


class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""

description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Callable[..., str]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""

@property
def args(self) -> dict:
"""The tool's input arguments."""
return self.args_schema.schema()["properties"]

def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)

async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async")

@classmethod
def from_function(
cls,
func: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
**kwargs: Any,
) -> StructuredTool:
name = name or func.__name__
description = description or func.__doc__
assert (
description is not None
), "Function must have a docstring if description not provided."

# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{name}{signature(func)} - {description.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", func)
return cls(
name=name,
func=func,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
**kwargs,
)
33 changes: 17 additions & 16 deletions tests/unit_tests/agents/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from langchain.agents.react.base import ReActDocstoreAgent, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchAgent
from langchain.agents.tools import Tool, tool
from langchain.tools.base import BaseTool, SchemaAnnotationError
from langchain.tools.base import BaseTool, SchemaAnnotationError, StructuredTool


def test_unnamed_decorator() -> None:
Expand All @@ -27,7 +27,7 @@ def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"

assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert not search_api.return_direct
assert search_api("test") == "API result"
Expand Down Expand Up @@ -145,7 +145,7 @@ def tool_func(arg1: int, arg2: bool, arg3: Optional[dict] = None) -> str:
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"

assert isinstance(tool_func, Tool)
assert isinstance(tool_func, BaseTool)
assert tool_func.args_schema == _MockSchema


Expand All @@ -159,7 +159,7 @@ def structured_tool_input(
"""Return the arguments directly."""
return f"{arg1} {arg2} {arg3}"

assert isinstance(structured_tool_input, Tool)
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.args_schema is not None
assert (
structured_tool_input.args_schema.schema()["properties"]
Expand All @@ -171,14 +171,14 @@ def structured_tool_input(
def test_structured_args_decorator_no_infer_schema() -> None:
"""Test functionality with structured arguments parsed as a decorator."""

@tool(infer_schema=False)
@tool
def structured_tool_input(
arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None
) -> str:
"""Return the arguments directly."""
return f"{arg1}, {arg2}, {opt_arg}"

assert isinstance(structured_tool_input, Tool)
assert isinstance(structured_tool_input, BaseTool)
assert structured_tool_input.name == "structured_tool_input"
args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}}
expected_result = "1, 0.001, {'foo': 'bar'}"
Expand All @@ -193,8 +193,9 @@ def unstructured_tool_input(tool_input: str) -> str:
"""Return the arguments directly."""
return f"{tool_input}"

assert isinstance(unstructured_tool_input, Tool)
assert isinstance(unstructured_tool_input, BaseTool)
assert unstructured_tool_input.args_schema is None
assert unstructured_tool_input.run("foo") == "foo"


def test_base_tool_inheritance_base_schema() -> None:
Expand Down Expand Up @@ -225,18 +226,18 @@ def test_tool_lambda_args_schema() -> None:
func=lambda tool_input: tool_input,
)
assert tool.args_schema is None
expected_args = {"tool_input": {"title": "Tool Input"}}
expected_args = {"tool_input": {"type": "string"}}
assert tool.args == expected_args


def test_tool_lambda_multi_args_schema() -> None:
def test_structured_tool_lambda_multi_args_schema() -> None:
"""Test args schema inference when the tool argument is a lambda function."""
tool = Tool(
tool = StructuredTool.from_function(
name="tool",
description="A tool",
func=lambda tool_input, other_arg: f"{tool_input}{other_arg}", # type: ignore
)
assert tool.args_schema is None
assert tool.args_schema is not None
expected_args = {
"tool_input": {"title": "Tool Input"},
"other_arg": {"title": "Other Arg"},
Expand Down Expand Up @@ -268,7 +269,7 @@ def empty_tool_input() -> str:
"""Return a constant."""
return "the empty result"

assert isinstance(empty_tool_input, Tool)
assert isinstance(empty_tool_input, BaseTool)
assert empty_tool_input.name == "empty_tool_input"
assert empty_tool_input.args == {}
assert empty_tool_input.run({}) == "the empty result"
Expand All @@ -282,7 +283,7 @@ def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"

assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert not search_api.return_direct

Expand All @@ -295,7 +296,7 @@ def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"

assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search"
assert search_api.return_direct

Expand All @@ -308,7 +309,7 @@ def search_api(query: str) -> str:
"""Search the API for the query."""
return "API result"

assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
assert search_api.name == "search_api"
assert search_api.return_direct

Expand All @@ -325,7 +326,7 @@ def search_api(
"""Search the API for the query."""
return f"arg_0={arg_0}, arg_1={arg_1}, ping={ping}"

assert isinstance(search_api, Tool)
assert isinstance(search_api, BaseTool)
result = search_api.run(
tool_input={
"arg_0": "foo",
Expand Down

0 comments on commit ec64e7d

Please sign in to comment.