Skip to content

Commit

Permalink
feat(tasks): accept args in BaseTask.run() (#1598)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Feb 18, 2025
1 parent 4ee4515 commit 89c32d4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
14 changes: 12 additions & 2 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class State(Enum):

output: Optional[T] = field(default=None, init=False)
context: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = field(factory=tuple, init=False)

@property
def execution_args(self) -> tuple:
return self._execution_args

def __rshift__(self, other: BaseTask | list[BaseTask]) -> BaseTask | list[BaseTask]:
if isinstance(other, list):
Expand Down Expand Up @@ -163,8 +168,10 @@ def before_run(self) -> None:
),
)

def run(self) -> T:
def run(self, *args) -> T:
try:
self._execution_args = args

self.state = BaseTask.State.RUNNING

self.before_run()
Expand Down Expand Up @@ -212,6 +219,7 @@ def can_run(self) -> bool:
def reset(self) -> BaseTask:
self.state = BaseTask.State.PENDING
self.output = None
self._execution_args = ()

return self

Expand All @@ -222,7 +230,9 @@ def try_run(self) -> T: ...
def full_context(self) -> dict[str, Any]:
# Need to deep copy so that the serialized context doesn't contain non-serializable data
context = deepcopy(self.context)
if self.structure is not None:
if self.structure is None:
context.update({"args": self._execution_args})
else:
context.update(self.structure.context(self))

return context
19 changes: 19 additions & 0 deletions tests/unit/tasks/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@ def test_full_context(self, task):
task.structure._execution_args = ("foo", "bar")

assert task.full_context == {"args": ("foo", "bar"), "structure": task.structure}
assert task.structure.execution_args == ("foo", "bar")

task.structure = None
task._execution_args = ("foo", "bar")

assert task.full_context == {"args": ("foo", "bar")}
assert task.execution_args == ("foo", "bar")

def test_is_pending(self, task):
task.state = task.State.PENDING
Expand All @@ -249,3 +256,15 @@ def test___str__(self, task):
assert str(task) == "foobar"
task.output = None
assert str(task) == ""

def test_run_args(self, task):
task.run("foo", "bar")

assert task._execution_args == ("foo", "bar")

def test_args_full_context(self):
task = MockTask()
task.context = {"foo": "buzz"}
task.run("foo", "bar")

assert task.full_context["args"] == ("foo", "bar")
10 changes: 7 additions & 3 deletions tests/unit/tasks/test_base_text_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ def test_full_context(self):
subtask = MockTextInputTask("test", context={"foo": "bar"})
child = MockTextInputTask("child")

assert parent.full_context == {}
assert subtask.full_context == {"foo": "bar"}
assert child.full_context == {}
assert parent.full_context == {
"args": (),
}
assert subtask.full_context == {"args": (), "foo": "bar"}
assert child.full_context == {
"args": (),
}

pipeline = Pipeline()

Expand Down

0 comments on commit 89c32d4

Please sign in to comment.