Skip to content

Commit

Permalink
feat: Add error message to execution if failed
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 2, 2024
1 parent 0791d51 commit 3395093
Show file tree
Hide file tree
Showing 16 changed files with 222 additions and 138 deletions.
26 changes: 25 additions & 1 deletion agents-api/agents_api/autogen/Executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ class CreateExecutionRequest(BaseModel):
"""
The input to the execution
"""
output: Any | None = None
"""
The output of the execution if it succeeded
"""
error: str | None = None
"""
The error of the execution if it failed
"""
metadata: dict[str, Any] | None = None


Expand Down Expand Up @@ -51,6 +59,14 @@ class Execution(BaseModel):
"""
The input to the execution
"""
output: Any | None = None
"""
The output of the execution if it succeeded
"""
error: str | None = None
"""
The error of the execution if it failed
"""
created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})]
"""
When this resource was created as UTC date-time
Expand Down Expand Up @@ -80,7 +96,15 @@ class TransitionEvent(BaseModel):
)
type: Annotated[
Literal[
"finish", "branch_finish", "wait", "resume", "error", "step", "cancelled"
"init",
"branch_init",
"finish",
"finish_branch",
"wait",
"resume",
"error",
"step",
"cancelled",
],
Field(json_schema_extra={"readOnly": True}),
]
Expand Down
10 changes: 9 additions & 1 deletion agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,15 @@ class ListResponse(BaseModel, Generic[DataT]):


TransitionType = Literal[
"finish", "branch_finish", "wait", "resume", "error", "step", "cancelled"
"init",
"branch_init",
"finish",
"finish_branch",
"wait",
"resume",
"error",
"step",
"cancelled",
]

assert Transition.model_fields["type"].annotation == TransitionType
Expand Down
26 changes: 17 additions & 9 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CreateTaskRequest,
CreateTransitionRequest,
Execution,
ExecutionStatus,
PartialTaskSpecDef,
PatchTaskRequest,
Session,
Expand All @@ -26,38 +27,40 @@
)

### NOTE: Here, "init" is NOT a real state, but a placeholder for the start state of the state machine
valid_transitions = {
valid_transitions: dict[TransitionType, list[TransitionType]] = {
# Start state
"init": ["wait", "error", "step", "cancelled", "finish", "branch_finish"],
"init": ["wait", "error", "step", "cancelled", "init_branch"],
"init_branch": ["wait", "error", "step", "cancelled"],
# End states
"finish": [],
"error": [],
"cancelled": [],
# Intermediate states
"wait": ["resume", "error", "cancelled"],
"resume": ["wait", "error", "step", "finish", "cancelled", "branch_finish"],
"step": ["wait", "error", "step", "finish", "cancelled", "branch_finish"],
"branch_finish": ["wait", "error", "step", "finish", "cancelled"],
"resume": ["wait", "error", "cancelled", "step", "finish", "finish_branch", "init_branch"],
"step": ["wait", "error", "cancelled", "step", "finish", "finish_branch", "init_branch"],
"finish_branch": ["wait", "error", "cancelled", "step", "finish", "init_branch"],
}

valid_previous_statuses = {
valid_previous_statuses: dict[ExecutionStatus, list[ExecutionStatus]] = {
"running": ["queued", "starting", "awaiting_input"],
"cancelled": ["queued", "starting", "awaiting_input", "running"],
}

transition_to_execution_status = {
transition_to_execution_status: dict[TransitionType, ExecutionStatus] = {
"init": "queued",
"init_branch": "running",
"wait": "awaiting_input",
"resume": "running",
"step": "running",
"finish": "succeeded",
"branch_finish": "running",
"finish_branch": "running",
"error": "failed",
"cancelled": "cancelled",
}


PendingTransition: Type[BaseModel] = create_partial_model(CreateTransitionRequest)
PartialTransition: Type[BaseModel] = create_partial_model(CreateTransitionRequest)


class ExecutionInput(BaseModel):
Expand Down Expand Up @@ -105,6 +108,11 @@ def current_step(self) -> Annotated[WorkflowStep, Field(exclude=True)]:
def is_last_step(self) -> Annotated[bool, Field(exclude=True)]:
return (self.cursor.step + 1) == len(self.current_workflow.steps)

@computed_field
@property
def is_first_step(self) -> Annotated[bool, Field(exclude=True)]:
return self.cursor.step == 0

def model_dump(self, *args, **kwargs) -> dict[str, Any]:
dump = super().model_dump(*args, **kwargs)
dump["_"] = self.current_input
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,30 @@

def validate_transition_targets(data: CreateTransitionRequest) -> None:
# Make sure the current/next targets are valid
if data.type in ("finish", "branch_finish", "error", "cancelled"):
assert (
data.next is None
), "Next target must be None for finish/branch_finish/error/cancelled"
match data.type:
case "finish" | "finish_branch" | "error" | "cancelled":
assert (
data.next is None
), "Next target must be None for finish/finish_branch/error/cancelled"

case "init_branch" | "init":
assert (
data.next and data.current.step == data.next.step == 0
), "Next target must be same as current for init_branch/init and step 0"

if data.type in ("wait", "init"):
assert data.next is None, "Next target must be None for wait/init"
case "wait":
assert data.next is None, "Next target must be None for wait"

if data.type in ("resume", "step"):
assert data.next is not None, "Next target must be provided for resume/step"
case "resume" | "step":
assert data.next is not None, "Next target must be provided for resume/step"

if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"
if data.next.workflow == data.current.workflow:
assert (
data.next.step > data.current.step
), "Next step must be greater than current"

case _:
raise ValueError(f"Invalid transition type: {data.type}")


@rewrap_exceptions(
Expand Down Expand Up @@ -159,6 +168,8 @@ def create_execution_transition(
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
output=data.output if data.type == "finish" else None,
error=str(data.output) if data.type == "error" and data.output else None,
)
)

Expand Down
10 changes: 7 additions & 3 deletions agents-api/agents_api/models/execution/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def update_execution(
task_id: UUID,
execution_id: UUID,
data: UpdateExecutionRequest,
output: dict | None = None,
error: str | None = None,
) -> tuple[list[str], dict]:
developer_id = str(developer_id)
task_id = str(task_id)
Expand All @@ -64,6 +66,8 @@ def update_execution(
**execution_data,
"task_id": task_id,
"execution_id": execution_id,
"output": output,
"error": error,
}
)

Expand Down Expand Up @@ -98,9 +102,9 @@ def update_execution(
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(
developer_id,
"tasks",
task_id=task_id,
parents=[("agents", "agent_id")],
"executions",
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query if valid_previous_statuses is not None else "",
update_query,
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/tasks/get_task_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
)
from ...dependencies.developer_id import get_developer_id
from ...models.task.get_task import get_task as get_task_query

from .router import router


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from temporalio.exceptions import ApplicationError

with workflow.unsafe.imports_passed_through():
from ..activities import task_steps
from ..autogen.openapi_model import (
from ...activities import task_steps
from ...autogen.openapi_model import (
CreateTransitionRequest,
ErrorWorkflowStep,
EvaluateStep,
Expand All @@ -32,13 +32,13 @@
WorkflowStep,
YieldStep,
)
from ..common.protocol.tasks import (
from ...common.protocol.tasks import (
ExecutionInput,
PendingTransition,
PartialTransition,
StepContext,
StepOutcome,
)
from ..env import debug, testing
from ...env import debug, testing


STEP_TO_ACTIVITY = {
Expand Down Expand Up @@ -90,6 +90,14 @@ async def transition(state, context, **kwargs) -> None:
)


# init
# init_branch
# run
# finish_branch
# finish

#

@workflow.defn
class TaskExecutionWorkflow:
@workflow.run
Expand Down Expand Up @@ -123,16 +131,16 @@ async def run(
case (True, TransitionTarget(workflow="main")):
state_type = "finish"
case (True, _):
state_type = "branch_finish"
state_type = "finish_branch"
case _, _:
state_type = "step"

state = PendingTransition(
state = PartialTransition(
type=state_type,
next=None
if context.is_last_step
else TransitionTarget(workflow=start.workflow, step=start.step + 1),
metadata={"__meta__": {"step_type": step_type.__name__}},
metadata={"workflow_step_type": step_type.__name__},
)

# ---
Expand Down Expand Up @@ -463,7 +471,7 @@ async def run(

# 6. Closing
# End if the last step
if state.type in ("finish", "branch_finish", "cancelled"):
if state.type in ("finish", "finish_branch", "cancelled"):
workflow.logger.info(f"Workflow finished with state: {state.type}")
return state.output

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def run(client, queries):
created_at,
updated_at,
output,
error,
metadata,
] :=
*executions {
Expand All @@ -35,6 +36,7 @@ def run(client, queries):
updated_at,
},
output = null,
error = null,
metadata = {}
:replace executions {
Expand All @@ -46,6 +48,7 @@ def run(client, queries):
input: Json,
output: Json? default null,
error: String? default null,
session_id: Uuid? default null,
metadata: Json default {},
created_at: Float default now(),
Expand Down
Loading

0 comments on commit 3395093

Please sign in to comment.