Skip to content

Commit

Permalink
fix(agents-api): Fix api-call tool implementation and test
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Oct 3, 2024
1 parent 040d774 commit c24e57a
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 50 deletions.
42 changes: 14 additions & 28 deletions agents-api/agents_api/activities/excecute_api_call.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Any, Optional, Union
from typing import Annotated, Any, Optional, TypedDict, Union

import httpx
from beartype import beartype
Expand All @@ -14,47 +14,33 @@
# from ..models.tools import get_tool_args_from_metadata


class RequestArgs(TypedDict):
content: Optional[str]
data: Optional[dict[str, Any]]
json_: Optional[dict[str, Any]]
cookies: Optional[dict[str, str]]
params: Optional[Union[str, dict[str, Any]]]


@beartype
async def execute_api_call(
context: StepContext,
tool_name: str,
api_call: ApiCallDef,
content: Optional[str] = None,
data: Optional[dict[str, Any]] = None,
json_: Annotated[Optional[dict[str, Any]], Field(None, alias="json")] = None,
cookies: Optional[dict[str, str]] = None,
params: Optional[Union[str, dict[str, Any]]] = None,
request_args: RequestArgs,
) -> Any:
developer_id = context.execution_input.developer_id
agent_id = context.execution_input.agent.id
task_id = context.execution_input.task.id

# TODO: Implement get_tool_args_from_metadata to get the arguments and setup for the api call
# merged_tool_setup = get_tool_args_from_metadata(
# developer_id=developer_id, agent_id=agent_id, task_id=task_id, arg_type="setup"
# )

# arguments = (
# merged_tool_args.get(tool_name, {}) | (integration.arguments or {}) | arguments
# )

try:
response = httpx.request(
response = httpx.request(
method=api_call.method,
url=str(api_call.url),
headers=api_call.headers,
content=content,
data=data,
json=json_,
cookies=cookies,
params=params,
follow_redirects=api_call.follow_redirects,
**request_args,
)

response_dict = {
"status_code": response.status_code,
# FIXME: We need to handle the headers properly and convert them to a plain dict
# "headers": response.headers,
# "content": response.content,
"content": response.content,
"json": response.json(),
}

Expand Down
11 changes: 10 additions & 1 deletion agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@beartype
async def base_evaluate(
exprs: str | list[str] | dict[str, str],
exprs: str | list[str] | dict[str, str] | dict[str, dict[str, str]],
values: dict[str, Any] = {},
extra_lambda_strs: dict[str, str] | None = None,
) -> Any | list[Any] | dict[str, Any]:
Expand Down Expand Up @@ -53,9 +53,18 @@ async def base_evaluate(
case list():
return [evaluator.eval(expr) for expr in exprs]

case dict() as d if all(isinstance(v, dict) for v in d.values()):
return {
k: {ik: evaluator.eval(iv) for ik, iv in v.items()}
for k, v in d.items()
}

case dict():
return {k: evaluator.eval(v) for k, v in exprs.items()}

case _:
raise ValueError(f"Invalid expression: {exprs}")

except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in base_evaluate: {e}")
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ class ToolCallStep(BaseModel):
"""
The tool to run
"""
arguments: dict[str, str] | Literal["_"] = "_"
arguments: dict[str, dict[str, str] | str] | Literal["_"] = "_"
"""
The input parameters for the tool (defaults to last step output)
"""
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/models/task/create_or_update_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def create_or_update_task(
data.metadata = data.metadata or {}
data.input_schema = data.input_schema or {}

task_data = task_to_spec(data).model_dump(exclude_none=True, exclude_unset=True, mode="json")
task_data = task_to_spec(data).model_dump(
exclude_none=True, exclude_unset=True, mode="json"
)
task_data.pop("task_id", None)
task_data["created_at"] = utcnow().timestamp()

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def create_worker(client: Client) -> Any:
from ..activities import task_steps
from ..activities.demo import demo_activity
from ..activities.embed_docs import embed_docs
from ..activities.excecute_api_call import execute_api_call
from ..activities.execute_integration import execute_integration
from ..activities. excecute_api_call import execute_api_call
from ..activities.mem_mgmt import mem_mgmt
from ..activities.mem_rating import mem_rating
from ..activities.summarization import summarization
Expand Down
19 changes: 5 additions & 14 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ async def run(
] == "api_call":
call = tool_call["api_call"]
tool_name = call["name"]
# arguments = call["arguments"]
arguments = call["arguments"]
apicall_spec = next(
(t for t in context.tools if t.name == tool_name), None
)
Expand All @@ -527,25 +527,16 @@ async def run(
follow_redirects=apicall_spec.spec["follow_redirects"],
)

# Extract the optional arguments for `content`, `data`, `json`, `cookies`, and `params`
content = call.get("content", None)
data = call.get("data", None)
json_ = call.get("json", None)
cookies = call.get("cookies", None)
params = call.get("params", None)
if "json_" in arguments:
arguments["json"] = arguments["json_"]
del arguments["json_"]

# Execute the API call using the `execute_api_call` function
tool_call_response = await workflow.execute_activity(
execute_api_call,
args=[
context,
tool_name,
api_call,
content,
data,
json_,
cookies,
params,
arguments,
],
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600
Expand Down
8 changes: 5 additions & 3 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,11 +470,13 @@ async def _(
"main": [
{
"tool": "hello",
"params": {"test": "_.test"},
"arguments": {
"params": {"test": "_.test"},
},
},
{
"evaluate": {"hello": "_.json.args"},
}
"evaluate": {"hello": "_.json.args.test"},
},
],
}
),
Expand Down
2 changes: 1 addition & 1 deletion typespec/tasks/steps.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ model ToolCallStepDef {
tool: validPythonIdentifier;

/** The input parameters for the tool (defaults to last step output) */
arguments: ExpressionObject<unknown> | "_" = "_";
arguments: NestedExpressionObject<unknown> | "_" = "_";
}

model PromptStep extends BaseWorkflowStep<"prompt"> {
Expand Down

0 comments on commit c24e57a

Please sign in to comment.