Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic tool -> single purpose #3697

Merged
merged 3 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 36 additions & 26 deletions langchain/agents/tools.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union

from pydantic import BaseModel, validate_arguments, validator
from pydantic import BaseModel, validator

from langchain.tools.base import (
BaseTool,
create_schema_from_function,
get_filtered_args,
)
from langchain.tools.base import BaseTool, StructuredTool


class Tool(BaseTool):
Expand All @@ -30,25 +25,38 @@ def validate_func_not_partial(cls, func: Callable) -> Callable:

@property
def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)
# For backwards compatibility, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}

def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
# For backwards compatibility. The tool must be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
raise ValueError(
f"Too many arguments to single-input tool {self.name}."
f" Args: {all_args}"
)
return tuple(all_args), {}

def _run(self, *args: Any, **kwargs: Any) -> str:
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)

async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async")

# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any
self, name: str, func: Callable, description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
Expand Down Expand Up @@ -107,22 +115,24 @@ def search_api(query: str) -> str:
"""

def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> Tool:
assert func.__doc__, "Function must have a docstring"
# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{tool_name}{signature(func)} - {func.__doc__.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{tool_name}Schema", func)
tool_ = Tool(
def _make_tool(func: Callable) -> BaseTool:
if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
name=tool_name,
return_direct=return_direct,
args_schema=args_schema,
infer_schema=infer_schema,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring"
return Tool(
name=tool_name,
func=func,
args_schema=_args_schema,
description=description,
description=f"{tool_name} tool",
return_direct=return_direct,
)
return tool_

return _make_tool

Expand Down
105 changes: 90 additions & 15 deletions langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union

from pydantic import (
BaseModel,
Expand All @@ -19,15 +19,6 @@
from langchain.callbacks.base import BaseCallbackManager


def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input


class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""

Expand Down Expand Up @@ -81,14 +72,20 @@ def _create_subset_model(
return create_model(name, **fields) # type: ignore


def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
def get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}


def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
def create_schema_from_function(
model_name: str,
func: Callable,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature."""
inferred_model = validate_arguments(func).model # type: ignore
# Pydantic adds placeholder virtual fields we need to strip
Expand All @@ -102,12 +99,23 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Interface LangChain tools must implement."""

name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool.

You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
return_direct: bool = False
"""Whether to return the tool's output directly. Setting this to True means

that after the tool is called, the AgentExecutor will stop looping.
"""
verbose: bool = False
"""Whether to log the tool's progress."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
"""Callback manager for this tool."""

class Config:
"""Configuration for this pydantic object."""
Expand Down Expand Up @@ -160,6 +168,14 @@ def _run(self, *args: Any, **kwargs: Any) -> Any:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""

def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatibility, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input

def run(
self,
tool_input: Union[str, Dict],
Expand All @@ -182,7 +198,7 @@ def run(
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
Expand Down Expand Up @@ -224,8 +240,8 @@ async def arun(
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = await self._arun(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
Expand All @@ -249,3 +265,62 @@ async def arun(
def __call__(self, tool_input: Union[str, dict]) -> Any:
"""Make tool callable."""
return self.run(tool_input)


class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""

description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Callable[..., str]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""

@property
def args(self) -> dict:
"""The tool's input arguments."""
return self.args_schema.schema()["properties"]

def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)

async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async")

@classmethod
def from_function(
cls,
func: Callable,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
infer_schema: bool = True,
**kwargs: Any,
) -> StructuredTool:
name = name or func.__name__
description = description or func.__doc__
assert (
description is not None
), "Function must have a docstring if description not provided."

# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{name}{signature(func)} - {description.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", func)
return cls(
name=name,
func=func,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
**kwargs,
)
Loading