Skip to content

Commit

Permalink
Merge pull request #483 from julep-ai/x/fix-task-execution
Browse files Browse the repository at this point in the history
  • Loading branch information
creatorrr authored Sep 3, 2024
2 parents f8d32e4 + 5ee187f commit e73ae61
Show file tree
Hide file tree
Showing 22 changed files with 531 additions and 276 deletions.
5 changes: 2 additions & 3 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ async def base_evaluate(
k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()
}

# TODO: We should make this frozen_box=True, but we need to make sure that
# we don't break anything
values = Box(values, frozen_box=False, conversion_box=False)
# frozen_box doesn't work coz we need some mutability in the values
values = Box(values, frozen_box=False, conversion_box=True)

evaluator = get_evaluator(names=values)

Expand Down
15 changes: 3 additions & 12 deletions agents-api/agents_api/activities/task_steps/transition_step.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from uuid import uuid4

from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...common.protocol.tasks import StepContext
from ...env import testing
from ...models.execution.create_execution_transition import (
create_execution_transition as create_execution_transition_query,
)
from ...models.execution.create_execution_transition import create_execution_transition


@beartype
Expand All @@ -24,7 +20,7 @@ async def transition_step(
transition_info.task_token = task_token

# Create transition
transition = create_execution_transition_query(
transition = create_execution_transition(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
Expand All @@ -34,13 +30,8 @@ async def transition_step(

return transition

async def mock_transition_step(
context: StepContext,
transition_info: CreateTransitionRequest,
) -> None:
# Does nothing
return None

mock_transition_step = transition_step

transition_step = activity.defn(name="transition_step")(
transition_step if not testing else mock_transition_step
Expand Down
22 changes: 20 additions & 2 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,31 @@
from simpleeval import EvalWithCompoundTypes, SimpleEval
from yaml import CSafeLoader

# TODO: We need to make sure that we dont expose any security issues
ALLOWED_FUNCTIONS = {
"zip": zip,
"abs": abs,
"all": all,
"any": any,
"bool": bool,
"dict": dict,
"enumerate": enumerate,
"float": float,
"frozenset": frozenset,
"int": int,
"len": len,
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"load_json": json.loads,
"set": set,
"str": str,
"sum": sum,
"tuple": tuple,
"zip": zip,
}


Expand Down
116 changes: 60 additions & 56 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ruff: noqa: F401, F403, F405
from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar
from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar, get_args
from uuid import UUID

from litellm.utils import _select_tokenizer as select_tokenizer
Expand Down Expand Up @@ -32,24 +32,45 @@ class ListResponse(BaseModel, Generic[DataT]):
# Aliases
# -------

CreateToolRequest = UpdateToolRequest
CreateOrUpdateAgentRequest = UpdateAgentRequest
CreateOrUpdateUserRequest = UpdateUserRequest
CreateOrUpdateSessionRequest = CreateSessionRequest

class CreateToolRequest(UpdateToolRequest):
pass


class CreateOrUpdateAgentRequest(UpdateAgentRequest):
pass


class CreateOrUpdateUserRequest(UpdateUserRequest):
pass


class CreateOrUpdateSessionRequest(CreateSessionRequest):
pass


ChatResponse = ChunkChatResponse | MessageChatResponse

# TODO: Figure out wtf... 🤷‍♂️
MapReduceStep = Main
ChatMLTextContentPart = Content
ChatMLImageContentPart = ContentModel
InputChatMLMessage = Message

class MapReduceStep(Main):
pass


class ChatMLTextContentPart(Content):
pass


class ChatMLImageContentPart(ContentModel):
pass


class InputChatMLMessage(Message):
pass


# Custom types (not generated correctly)
# --------------------------------------

# TODO: Remove these when auto-population is fixed

ChatMLContent = (
list[ChatMLTextContentPart | ChatMLImageContentPart]
| Tool
Expand All @@ -65,48 +86,23 @@ class ListResponse(BaseModel, Generic[DataT]):
]
)

ChatMLRole = Literal[
"user",
"assistant",
"system",
"function",
"function_response",
"function_call",
"auto",
]
assert BaseEntry.model_fields["role"].annotation == ChatMLRole

ChatMLSource = Literal[
"api_request", "api_response", "tool_response", "internal", "summarizer", "meta"
]
assert BaseEntry.model_fields["source"].annotation == ChatMLSource


ExecutionStatus = Literal[
"queued",
"starting",
"running",
"awaiting_input",
"succeeded",
"failed",
"cancelled",
]
assert Execution.model_fields["status"].annotation == ExecutionStatus


TransitionType = Literal[
"init",
"init_branch",
"finish",
"finish_branch",
"wait",
"resume",
"error",
"step",
"cancelled",
]

assert Transition.model_fields["type"].annotation == TransitionType
# Extract ChatMLRole
ChatMLRole = BaseEntry.model_fields["role"].annotation

# Extract ChatMLSource
ChatMLSource = BaseEntry.model_fields["source"].annotation

# Extract ExecutionStatus
ExecutionStatus = Execution.model_fields["status"].annotation

# Extract TransitionType
TransitionType = Transition.model_fields["type"].annotation

# Assertions to ensure consistency (optional, but recommended for runtime checks)
assert ChatMLRole == BaseEntry.model_fields["role"].annotation
assert ChatMLSource == BaseEntry.model_fields["source"].annotation
assert ExecutionStatus == Execution.model_fields["status"].annotation
assert TransitionType == Transition.model_fields["type"].annotation


# Create models
Expand Down Expand Up @@ -155,8 +151,8 @@ def from_model_input(
)


# Task related models
# -------------------
# Workflow related models
# -----------------------

WorkflowStep = (
EvaluateStep
Expand Down Expand Up @@ -185,6 +181,10 @@ class Workflow(BaseModel):
steps: list[WorkflowStep]


# Task spec helper models
# ----------------------


class TaskToolDef(BaseModel):
type: str
name: str
Expand Down Expand Up @@ -223,6 +223,10 @@ class Task(_Task):
)


# Patch some models to allow extra fields
# --------------------------------------


_CreateTaskRequest = CreateTaskRequest


Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/clients/cozo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Dict

from pycozo.client import Client

Expand All @@ -10,7 +10,7 @@
options.update({"auth": cozo_auth})


def get_cozo_client() -> Any:
def get_cozo_client() -> Client:
client = getattr(app.state, "cozo_client", Client("http", options=options))
if not hasattr(app.state, "cozo_client"):
app.state.cozo_client = client
Expand Down
70 changes: 62 additions & 8 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,53 @@
WorkflowStep,
)

### NOTE: Here, "init" is NOT a real state, but a placeholder for the start state of the state machine
# TODO: Maybe we should use a library for this

# State Machine
#
# init -> wait | error | step | cancelled | init_branch | finish
# init_branch -> wait | error | step | cancelled | finish_branch
# wait -> resume | error | cancelled
# resume -> wait | error | cancelled | step | finish | finish_branch | init_branch
# step -> wait | error | cancelled | step | finish | finish_branch | init_branch
# finish_branch -> wait | error | cancelled | step | finish | init_branch
# error ->

## Mermaid Diagram
# ```mermaid
# ---
# title: Execution state machine
# ---
# stateDiagram-v2
# [*] --> queued
# queued --> starting
# queued --> cancelled
# starting --> cancelled
# starting --> failed
# starting --> running
# running --> running
# running --> awaiting_input
# running --> cancelled
# running --> failed
# running --> succeeded
# awaiting_input --> running
# awaiting_input --> cancelled
# cancelled --> [*]
# succeeded --> [*]
# failed --> [*]

# ```
# TODO: figure out how to type this
valid_transitions: dict[TransitionType, list[TransitionType]] = {
# Start state
"init": ["wait", "error", "step", "cancelled", "init_branch"],
"init_branch": ["wait", "error", "step", "cancelled"],
"init": ["wait", "error", "step", "cancelled", "init_branch", "finish"],
"init_branch": ["wait", "error", "step", "cancelled", "finish_branch"],
# End states
"finish": [],
"error": [],
"cancelled": [],
# Intermediate states
"wait": ["resume", "error", "cancelled"],
"wait": ["resume", "cancelled"],
"resume": [
"wait",
"error",
Expand All @@ -59,8 +95,13 @@
} # type: ignore

valid_previous_statuses: dict[ExecutionStatus, list[ExecutionStatus]] = {
"running": ["queued", "starting", "awaiting_input"],
"running": ["starting", "awaiting_input", "running"],
"starting": ["queued"],
"queued": [],
"awaiting_input": ["starting", "running"],
"cancelled": ["queued", "starting", "awaiting_input", "running"],
"succeeded": ["starting", "running"],
"failed": ["starting", "running"],
} # type: ignore

transition_to_execution_status: dict[TransitionType | None, ExecutionStatus] = {
Expand Down Expand Up @@ -100,12 +141,12 @@ class StepContext(BaseModel):

@computed_field
@property
def outputs(self) -> Annotated[list[dict[str, Any]], Field(exclude=True)]:
def outputs(self) -> list[dict[str, Any]]: # included in dump
return self.inputs[1:]

@computed_field
@property
def current_input(self) -> Annotated[dict[str, Any], Field(exclude=True)]:
def current_input(self) -> dict[str, Any]: # included in dump
return self.inputs[-1]

@computed_field
Expand All @@ -130,9 +171,22 @@ def is_last_step(self) -> Annotated[bool, Field(exclude=True)]:
def is_first_step(self) -> Annotated[bool, Field(exclude=True)]:
return self.cursor.step == 0

@computed_field
@property
def is_main(self) -> Annotated[bool, Field(exclude=True)]:
return self.cursor.workflow == "main"

def model_dump(self, *args, **kwargs) -> dict[str, Any]:
dump = super().model_dump(*args, **kwargs)
dump["_"] = self.current_input

# Merge execution inputs into the dump dict
execution_input: dict = dump.pop("execution_input")
current_input: Any = dump.pop("current_input")
dump = {
**dump,
**execution_input,
"_": current_input,
}

return dump

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def create_execution_transition(
?[valid] :=
matched[prev_transitions],
found = length(prev_transitions),
valid = assert(found > 0, "Invalid transition"),
valid = if($next_type == "init", found == 0, found > 0),
assert(valid, "Invalid transition"),
"""

# Prepare the insert query
Expand Down
Loading

0 comments on commit e73ae61

Please sign in to comment.