Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Tool single-input #3684

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions langchain/agents/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Optional, Type, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union

from pydantic import BaseModel, validate_arguments, validator

Expand Down Expand Up @@ -30,25 +30,47 @@ def validate_func_not_partial(cls, func: Callable) -> Callable:

@property
def args(self) -> dict:
"""The tool's input arguments."""
if self.args_schema is not None:
return self.args_schema.schema()["properties"]
else:
inferred_model = validate_arguments(self.func).model # type: ignore
return get_filtered_args(inferred_model, self.func)

def _run(self, *args: Any, **kwargs: Any) -> str:
inferred_model = validate_arguments(self.func).model # type: ignore
filtered_args = get_filtered_args(
inferred_model, self.func, invalid_args={"args", "kwargs"}
)
if filtered_args:
return filtered_args
# For backwards compatability, if the function signature is ambiguous,
# assume it takes a single string input.
return {"tool_input": {"type": "string"}}

def _to_args_and_kwargs(self, tool_input: str | Dict) -> Tuple[Tuple, Dict]:
"""Convert tool input to pydantic model."""
args, kwargs = super()._to_args_and_kwargs(tool_input)
if self.is_single_input:
# For backwards compatability. If no schema is inferred,
# the tool must assume it should be run with a single input
all_args = list(args) + list(kwargs.values())
if len(all_args) != 1:
raise ValueError(
f"Too many arguments to single-input tool {self.name}."
f" Args: {all_args}"
)
return tuple(all_args), {}
return args, kwargs

def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)

async def _arun(self, *args: Any, **kwargs: Any) -> str:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
raise NotImplementedError("Tool does not support async")

# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Callable[[str], str], description: str, **kwargs: Any
self, name: str, func: Callable, description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
Expand Down
34 changes: 19 additions & 15 deletions langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union

from pydantic import (
BaseModel,
Expand All @@ -19,15 +19,6 @@
from langchain.callbacks.base import BaseCallbackManager


def _to_args_and_kwargs(run_input: Union[str, Dict]) -> Tuple[Sequence, dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(run_input, str):
return (run_input,), {}
else:
return [], run_input


class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""

Expand Down Expand Up @@ -81,11 +72,16 @@ def _create_subset_model(
return create_model(name, **fields) # type: ignore


def get_filtered_args(inferred_model: Type[BaseModel], func: Callable) -> dict:
def get_filtered_args(
inferred_model: Type[BaseModel],
func: Callable,
invalid_args: Optional[Set[str]] = None,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {k: schema[k] for k in valid_keys}
invalid_args = invalid_args or set()
return {k: schema[k] for k in valid_keys if k not in invalid_args}


def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseModel]:
Expand Down Expand Up @@ -160,6 +156,14 @@ def _run(self, *args: Any, **kwargs: Any) -> Any:
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
"""Use the tool asynchronously."""

def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
# For backwards compatability, if run_input is a string,
# pass as a positional argument.
if isinstance(tool_input, str):
return (tool_input,), {}
else:
return (), tool_input

def run(
self,
tool_input: Union[str, Dict],
Expand All @@ -182,7 +186,7 @@ def run(
**kwargs,
)
try:
tool_args, tool_kwargs = _to_args_and_kwargs(tool_input)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = self._run(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
self.callback_manager.on_tool_error(e, verbose=verbose_)
Expand Down Expand Up @@ -224,8 +228,8 @@ async def arun(
)
try:
# We then call the tool on the tool input to get an observation
args, kwargs = _to_args_and_kwargs(tool_input)
observation = await self._arun(*args, **kwargs)
tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input)
observation = await self._arun(*tool_args, **tool_kwargs)
except (Exception, KeyboardInterrupt) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_tool_error(e, verbose=verbose_)
Expand Down
22 changes: 21 additions & 1 deletion tests/unit_tests/agents/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test tool utils."""
from datetime import datetime
from functools import partial
from typing import Optional, Type, Union
from typing import Any, Optional, Type, Union
from unittest.mock import MagicMock

import pydantic
Expand Down Expand Up @@ -423,3 +423,23 @@ def the_tool(foo: str, bar: str) -> str:
f" multi-input tool {the_tool.name}.",
):
agent_cls.from_llm_and_tools(MagicMock(), [the_tool]) # type: ignore


def test_tool_no_args_specified_assumes_str() -> None:
"""Older tools could assume *args and **kwargs were passed in."""

def ambiguous_function(*args: Any, **kwargs: Any) -> str:
"""An ambiguously defined function."""
return args[0]

some_tool = Tool(
name="chain_run",
description="Run the chain",
func=ambiguous_function,
)
expected_args = {"tool_input": {"type": "string"}}
assert some_tool.args == expected_args
assert some_tool.run("foobar") == "foobar"
assert some_tool.run({"tool_input": "foobar"}) == "foobar"
with pytest.raises(ValueError, match="Too many arguments to single-input tool"):
some_tool.run({"tool_input": "foobar", "other_input": "bar"})