Commit
feat(tasks-prompt): add reflect_on_tool_use flag to disable LLM reflection
collindutter committed Mar 6, 2025
1 parent 9b6bfff commit 531cee8
Showing 7 changed files with 75 additions and 11 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -178,8 +178,9 @@ The important thing to note here is that no matter how big the webpage is it can
In the above example, we set [off_prompt](https://docs.griptape.ai/stable/griptape-framework/structures/task-memory.md#off-prompt) to `True`, which means that the LLM can never see the data it manipulates, but can send it to other Tools.

> [!IMPORTANT]
-> This example uses Griptape's [PromptTask](https://docs.griptape.ai/stable/griptape-framework/structures/tasks/#prompt-task) with `tools`, which requires a highly capable LLM to function correctly. By default, Griptape uses the [OpenAiChatPromptDriver](https://docs.griptape.ai/stable/griptape-framework/drivers/prompt-drivers/#openai-chat); for another powerful LLM try swapping to the [AnthropicPromptDriver](https://docs.griptape.ai/stable/griptape-framework/drivers/prompt-drivers/#anthropic)!
-> If you're using a less powerful LLM, consider using the [ToolTask](https://docs.griptape.ai/stable/griptape-framework/structures/tasks/#tool-task) instead, as the `PromptTask` with `tools` might not work properly or at all.
+> This example uses Griptape's [PromptTask](https://docs.griptape.ai/stable/griptape-framework/structures/tasks/#prompt-task) with multiple `tools`, which requires a highly capable LLM to function correctly.
+> If you're using a less powerful LLM, consider setting [reflect_on_tool_use](https://docs.griptape.ai/latest/griptape-framework/structures/tasks/#reflect-on-tool-use) to `False` to have the LLM return tool outputs directly.
+> You can then make the same tool calls yourself rather than having the LLM coordinate them.
[Check out our docs](https://docs.griptape.ai/stable/griptape-framework/drivers/prompt-drivers/) to learn more about how to use Griptape with other LLM providers like Anthropic, Claude, Hugging Face, and Azure.

15 changes: 15 additions & 0 deletions docs/griptape-framework/structures/src/tasks_reflect_on_tool_use.py
@@ -0,0 +1,15 @@
from griptape.drivers.web_search.duck_duck_go import DuckDuckGoWebSearchDriver
from griptape.tasks import PromptTask
from griptape.tools import WebScraperTool, WebSearchTool

search_task = PromptTask(
    tools=[WebSearchTool(web_search_driver=DuckDuckGoWebSearchDriver())],
    reflect_on_tool_use=False,
)
search_results = search_task.run("Do two searches, one for 'vim' and one for 'emacs'.")

scrape_task = PromptTask(
    tools=[WebScraperTool()],
    reflect_on_tool_use=True,
)
answer = scrape_task.run(["Compare and contrast vim and emacs:", search_results.to_text()])
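
Because `reflect_on_tool_use` is `False`, the first task hands back the search Tool's outputs directly rather than an LLM-written summary; the second task leaves reflection enabled, so the LLM can reason over those results (and any scraped pages) to produce the final comparison.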
20 changes: 20 additions & 0 deletions docs/griptape-framework/structures/tasks.md
@@ -80,6 +80,22 @@ You can pass in one or more Tools which the LLM will decide to use through Chain
--8<-- "docs/griptape-framework/structures/logs/tasks_4.txt"
```

#### Reflect On Tool Use

By default, Griptape will pass the results of Tool runs back to the LLM for reflection. This enables the LLM to reason about the results and potentially use additional Tools.

However, there may be times when you want the LLM to give you back the results directly, without reflection.
You can disable this behavior by setting [reflect_on_tool_use](../../reference/griptape/tasks/prompt_task.md#griptape.tasks.prompt_task.PromptTask.reflect_on_tool_use) to `False`.

```python
--8<-- "docs/griptape-framework/structures/src/tasks_reflect_on_tool_use.py"
```

!!! important

    Disabling reflection will prevent the LLM from using one Tool to inform the use of another Tool.
    Instead, you must coordinate the Tool uses yourself.
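
For illustration, a minimal sketch of coordinating Tool uses yourself, assuming the `CalculatorTool` that ships with Griptape (the task and variable names here are illustrative, not part of this change):

```python
from griptape.tasks import PromptTask
from griptape.tools import CalculatorTool

# With reflection disabled, the task returns the raw Tool outputs.
calc_task = PromptTask(
    tools=[CalculatorTool()],
    reflect_on_tool_use=False,
)
tool_outputs = calc_task.run("What is 7 * 6?")

# Coordinate the next step yourself: pass the Tool outputs to a second
# task that reflects as usual (reflect_on_tool_use defaults to True).
summary_task = PromptTask()
summary = summary_task.run(["Explain this result in one sentence:", tool_outputs.to_text()])
```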

### Images

If the model supports it, you can also pass image inputs:
@@ -98,6 +114,10 @@ If the model supports it, you can also pass image inputs:

## Tool Task

!!! warning

    `ToolTask` is deprecated and will be removed in a future version. Use [Prompt Task](./tasks.md#prompt-task) with [Reflect On Tool Use](./tasks.md#reflect-on-tool-use) instead.

Another way to use [Griptape Tools](../../griptape-framework/tools/index.md) is with a [Tool Task](../../reference/griptape/tasks/tool_task.md).
This Task takes in a single Tool which the LLM will use without Chain of Thought (CoT) reasoning. Because this Task does not use CoT, it is better suited for less capable models.
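
Given the deprecation, here is a minimal migration sketch, assuming the `CalculatorTool` that ships with Griptape (the tool choice is illustrative):

```python
from griptape.tasks import PromptTask, ToolTask
from griptape.tools import CalculatorTool

# Deprecated: running a ToolTask now emits a DeprecationWarning.
deprecated_task = ToolTask(tool=CalculatorTool())

# Replacement: a PromptTask with a single Tool and reflection disabled
# likewise returns the Tool's output directly.
replacement_task = PromptTask(tools=[CalculatorTool()], reflect_on_tool_use=False)
```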

17 changes: 8 additions & 9 deletions griptape/tasks/prompt_task.py
@@ -77,6 +77,7 @@ class PromptTask(
        kw_only=True,
    )
    response_stop_sequence: str = field(default=RESPONSE_STOP_SEQUENCE, kw_only=True)
    reflect_on_tool_use: bool = field(default=True, kw_only=True)

    @property
    def rulesets(self) -> list:
@@ -202,19 +203,17 @@ def try_run(self) -> ListArtifact | TextArtifact | AudioArtifact | GenericArtifa
        if self.tools:
            subtask = self.add_subtask(ActionsSubtask(output))

-            while True:
-                if subtask.output is None:
-                    if len(self.subtasks) >= self.max_subtasks:
-                        subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task")
-                    else:
-                        subtask.run()
+            while subtask.output is None:
+                if len(self.subtasks) >= self.max_subtasks:
+                    subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task")
+                else:
+                    subtask.run()

+                    if self.reflect_on_tool_use:
                        output = self.prompt_driver.run(self.prompt_stack).to_artifact(
                            meta={"is_react_prompt": not self.prompt_driver.use_native_tools}
                        )
                        subtask = self.add_subtask(ActionsSubtask(output))
+                    else:
+                        break

            output = subtask.output

@@ -227,7 +226,7 @@ def try_run(self) -> ListArtifact | TextArtifact | AudioArtifact | GenericArtifa
                return ModelArtifact(TypeAdapter(self.output_schema).validate_json(output.value))
            else:
                raise ValueError(f"Unsupported output schema type: {type(self.output_schema)}")
-        elif isinstance(output, (TextArtifact, AudioArtifact, JsonArtifact, ErrorArtifact)):
+        elif isinstance(output, (ListArtifact, TextArtifact, AudioArtifact, JsonArtifact, ErrorArtifact)):
            return output
        else:
            raise ValueError(f"Unsupported output type: {type(output)}")
6 changes: 6 additions & 0 deletions griptape/tasks/tool_task.py
@@ -2,6 +2,7 @@

import json
import re
import warnings
from typing import TYPE_CHECKING, Optional

from attrs import define, field
@@ -64,6 +65,11 @@ def actions_schema(self) -> Schema:
        return self._actions_schema_for_tools([self.tool])

    def try_run(self) -> ListArtifact | TextArtifact | ErrorArtifact:
        warnings.warn(
            "`ToolTask` is deprecated and will be removed in a future release. Use `PromptTask` with `reflect_on_tool_use=False` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        result = self.prompt_driver.run(self.prompt_stack)

        if self.prompt_driver.use_native_tools:
15 changes: 15 additions & 0 deletions tests/unit/tasks/test_prompt_task.py
@@ -285,3 +285,18 @@ def test_parse_output_schema(self, structured_output_strategy, output_schema, ex
            assert output.value.model_dump_json() == expected_output.model_dump_json()
        else:
            assert output.value == expected_output

    @pytest.mark.parametrize(
        ("reflect_on_tool_use", "expected"),
        [(True, "mock output"), (False, "ack test-value")],
    )
    def test_reflect_on_tool_use(self, reflect_on_tool_use, expected):
        task = PromptTask(
            tools=[MockTool()],
            prompt_driver=MockPromptDriver(use_native_tools=True),
            reflect_on_tool_use=reflect_on_tool_use,
        )

        result = task.run()

        assert result.to_text() == expected
8 changes: 8 additions & 0 deletions tests/unit/tasks/test_tool_task.py
@@ -285,3 +285,11 @@ def test_from_dict(self):

        deserialized_tool_task = ToolTask.from_dict(serialized_tool_task)
        assert isinstance(deserialized_tool_task, ToolTask)

    def test_deprecated_warning(self, agent):
        task = ToolTask(tool=MockTool())

        agent.add_task(task)

        with pytest.warns(DeprecationWarning):
            task.run()
