Skip to content

Commit

Permalink
feat(tasks): accept args in BaseTask.run()
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 17, 2025
1 parent 6f4ac33 commit 5772e2e
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 11 deletions.
6 changes: 3 additions & 3 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC):
)
meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True)
fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = ()
_execution_args: tuple = field(factory=tuple, init=False)
_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()), init=False)

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -154,7 +154,6 @@ def resolve_relationships(self) -> None:
@observable
def before_run(self, args: Any) -> None:
super().before_run(args)
self._execution_args = args

[task.reset() for task in self.tasks]

Expand Down Expand Up @@ -197,7 +196,8 @@ def after_run(self) -> None:
def add_task(self, task: BaseTask) -> BaseTask: ...

@observable
def run(self, *args) -> Structure:
def run(self, *args, **kwargs) -> Structure:
self._execution_args = args
self.before_run(args)

result = self.try_run(*args)
Expand Down
14 changes: 12 additions & 2 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class State(Enum):

output: Optional[T] = field(default=None, init=False)
context: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = field(factory=tuple, init=False)

@property
def execution_args(self) -> tuple:
return self._execution_args

def __rshift__(self, other: BaseTask | list[BaseTask]) -> BaseTask | list[BaseTask]:
if isinstance(other, list):
Expand Down Expand Up @@ -163,8 +168,10 @@ def before_run(self) -> None:
),
)

def run(self) -> T:
def run(self, *args) -> T:
try:
self._execution_args = args

self.state = BaseTask.State.RUNNING

self.before_run()
Expand Down Expand Up @@ -212,6 +219,7 @@ def can_run(self) -> bool:
def reset(self) -> BaseTask:
self.state = BaseTask.State.PENDING
self.output = None
self._execution_args = ()

return self

Expand All @@ -222,7 +230,9 @@ def try_run(self) -> T: ...
def full_context(self) -> dict[str, Any]:
# Need to deep copy so that the serialized context doesn't contain non-serializable data
context = deepcopy(self.context)
if self.structure is not None:
if self.structure is None:
context.update({"args": self._execution_args})
else:
context.update(self.structure.context(self))

return context
3 changes: 2 additions & 1 deletion tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,12 @@ def test_context(self):

agent.add_task(task)

agent.run("hello")
agent.run("hello", foo="bar")

context = agent.context(task)

assert context["structure"] == agent
assert context["args"] == ("hello",)

def test_task_memory_defaults(self, mock_config):
agent = Agent()
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def test_context(self):

assert context["parent_output"] is None

pipeline.run()
pipeline.run("hello", foo="bar")

context = pipeline.context(task)

Expand All @@ -365,6 +365,7 @@ def test_context(self):
assert context["structure"] == pipeline
assert context["parent"] == parent
assert context["child"] == child
assert context["args"] == ("hello",)

def test_run_with_error_artifact(self, error_artifact_task, waiting_task):
end_task = PromptTask("end")
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ def test_context(self):

assert context["parent_outputs"] == {}

workflow.run()
workflow.run("hello", foo="bar")

context = workflow.context(task)

Expand All @@ -749,6 +749,7 @@ def test_context(self):
assert context["structure"] == workflow
assert context["parents"] == {parent.id: parent}
assert context["children"] == {child.id: child}
assert context["args"] == ("hello",)

def test_run_with_error_artifact(self, error_artifact_task, waiting_task):
end_task = PromptTask("end")
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/tasks/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ def test_full_context(self, task):
task.structure._execution_args = ("foo", "bar")

assert task.full_context == {"args": ("foo", "bar"), "structure": task.structure}
assert task.structure.execution_args == ("foo", "bar")

task.structure = None
task._execution_args = ("foo", "bar")

assert task.full_context == {"args": ("foo", "bar")}
assert task.execution_args == ("foo", "bar")

def test_is_pending(self, task):
task.state = task.State.PENDING
Expand All @@ -249,3 +256,15 @@ def test___str__(self, task):
assert str(task) == "foobar"
task.output = None
assert str(task) == ""

def test_run_args(self, task):
task.run("foo", "bar")

assert task._execution_args == ("foo", "bar")

def test_args_full_context(self):
task = MockTask()
task.context = {"foo": "buzz"}
task.run("foo", "bar")

assert task.full_context["args"] == ("foo", "bar")
10 changes: 7 additions & 3 deletions tests/unit/tasks/test_base_text_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ def test_full_context(self):
subtask = MockTextInputTask("test", context={"foo": "bar"})
child = MockTextInputTask("child")

assert parent.full_context == {}
assert subtask.full_context == {"foo": "bar"}
assert child.full_context == {}
assert parent.full_context == {
"args": (),
}
assert subtask.full_context == {"args": (), "foo": "bar"}
assert child.full_context == {
"args": (),
}

pipeline = Pipeline()

Expand Down

0 comments on commit 5772e2e

Please sign in to comment.