Skip to content

Commit

Permalink
feat(agents-api): Make the sample work (more or less)
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 24, 2024
1 parent 936ffbe commit 3183267
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 13 deletions.
13 changes: 12 additions & 1 deletion agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import json
from typing import Any

import re2
import yaml
from beartype import beartype
from simpleeval import EvalWithCompoundTypes, SimpleEval
from yaml import CSafeLoader

ALLOWED_FUNCTIONS = {
"len": len,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"search_regex": lambda pattern, string: re2.search(pattern, string),
"load_json": json.loads,
}

@beartype
def get_evaluator(names: dict[str, Any]) -> SimpleEval:
evaluator = EvalWithCompoundTypes(names=names)
evaluator = EvalWithCompoundTypes(names=names, functions=ALLOWED_FUNCTIONS)
return evaluator


Expand Down
6 changes: 4 additions & 2 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ async def start_execution(
client=client,
)

job_id=uuid4()

try:
handle = await run_task_execution_workflow(
execution_input=execution_input,
job_id=uuid4(),
job_id=job_id,
)

except Exception as e:
Expand Down Expand Up @@ -130,5 +132,5 @@ async def create_task_execution(
return ResourceCreatedResponse(
id=execution.id,
created_at=execution.created_at,
jobs=[],
jobs=[handle.id],
)
3 changes: 3 additions & 0 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ async def transition(**kwargs) -> None:
args=map_reduce_args,
)

if hasattr(output, "model_dump"):
output = output.model_dump()

initial = await execute_activity(
task_steps.base_evaluate,
args=[
Expand Down
62 changes: 61 additions & 1 deletion agents-api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions agents-api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ simpleeval = "^0.9.13"
lz4 = "^4.3.3"

pyyaml = "^6.0.2"
google-re2 = "^1.1.20240702"
[tool.poetry.group.dev.dependencies]
ipython = "^8.26.0"
ruff = "^0.5.5"
Expand Down
6 changes: 3 additions & 3 deletions agents-api/tests/sample_tasks/find_selector.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ main:
content:
- type: image_url
image_url:
url: "inputs[0].screenshot_base64"
url: "{{inputs[0].screenshot_base64}}"

over: _.parameters
reduce: >-
Expand All @@ -74,8 +74,8 @@ main:
result: >-
[
{"value": result["value"], "network_request": request}
for request in execution.input.network_requests
if result["value"] in nr.response.body
for request in inputs[0]["network_requests"]
if result["value"] in nr["response"]["body"]
for result in _
if result["found"]
]
Expand Down
62 changes: 56 additions & 6 deletions agents-api/tests/sample_tasks/test_find_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from uuid import uuid4

from ward import test
from ward import raises, test

from ..fixtures import cozo_client, test_agent, test_developer_id
from ..utils import patch_embed_acompletion, patch_http_client_with_temporal
Expand Down Expand Up @@ -39,7 +39,7 @@ async def _(
).raise_for_status()


@test("workflow sample: find-selector start")
@test("workflow sample: find-selector start with bad input should fail")
async def _(
cozo_client=cozo_client,
developer_id=test_developer_id,
Expand Down Expand Up @@ -68,12 +68,62 @@ async def _(

execution_data = dict(input={"test": "input"})

execution = (
with raises(BaseException):
make_request(
method="POST",
url=f"/tasks/{task_id}/executions",
json=execution_data,
)
.raise_for_status()
.json()
).raise_for_status()


@test("workflow sample: find-selector start with correct input")
async def _(
cozo_client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
agent_id = str(agent.id)
task_id = str(uuid4())


with patch_embed_acompletion(), open(
f"{this_dir}/find_selector.yaml", "r"
) as sample_file:
task_def = sample_file.read()

async with patch_http_client_with_temporal(
cozo_client=cozo_client, developer_id=developer_id
) as (
make_request,
temporal_client,
):
make_request(
method="POST",
url=f"/agents/{agent_id}/tasks/{task_id}",
headers={"Content-Type": "application/x-yaml"},
data=task_def,
).raise_for_status()

input = dict(
screenshot_base64="iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA",
network_requests=[{
"request": {},
"response": {
"body": "Lady Gaga"
}
}],
parameters=["name"],
)
execution_data = dict(input=input)

execution_created = make_request(
method="POST",
url=f"/tasks/{task_id}/executions",
json=execution_data,
).json()

handle = temporal_client.get_workflow_handle(
execution_created["jobs"][0]
)

await handle.result()

0 comments on commit 3183267

Please sign in to comment.