diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 0e5f5055c..595ec3175 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -59,6 +59,7 @@ from ...common.retry_policies import DEFAULT_RETRY_POLICY from ...env import ( debug, + task_max_parallelism, temporal_heartbeat_timeout, temporal_schedule_to_close_timeout, testing, @@ -304,7 +305,9 @@ async def _handle_MapReduceStep( return PartialTransition(output=None) parallelism = step.parallelism - if parallelism is None or parallelism == 1: + if parallelism is None: + parallelism = task_max_parallelism + if parallelism == 1: result = await execute_map_reduce_step( context=self.context, execution_input=self.context.execution_input, diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 3bf023bb6..c0e4fb6f6 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -15,6 +15,7 @@ TransitionTarget, Workflow, WorkflowStep, + YieldStep, ) from ...common.protocol.tasks import ( ExecutionInput, @@ -274,6 +275,10 @@ async def execute_map_reduce_step_parallel( workflow.logger.info(f"MapReduce step: Processing {len(items)} items") results = initial + if isinstance(context.current_step.map, YieldStep): + msg = "Subworkflow step not supported in parallel map reduce" + raise ValueError(msg) + parallelism = min(parallelism, task_max_parallelism) assert parallelism > 1, "Parallelism must be greater than 1" diff --git a/agents-api/tests/test_workflow_helpers.py b/agents-api/tests/test_workflow_helpers.py new file mode 100644 index 000000000..a1c8bf43f --- /dev/null +++ b/agents-api/tests/test_workflow_helpers.py @@ -0,0 +1,119 @@ +import uuid +from unittest.mock import patch + +from agents_api.autogen.openapi_model import ( + Agent, + MapReduceStep, + PromptItem, + PromptStep, + TaskSpecDef, + TransitionTarget, + Workflow, + YieldStep, +) +from agents_api.common.protocol.tasks import ( + ExecutionInput, + StepContext, +) +from agents_api.common.utils.datetime import utcnow +from agents_api.workflows.task_execution.helpers import execute_map_reduce_step_parallel +from ward import raises, test + + +@test("execute_map_reduce_step_parallel: subworkflow step not supported") +async def _(): + async def _resp(): + return "response" + + subworkflow_step = YieldStep( + kind_="yield", workflow="subworkflow", arguments={"test": "$ _"} + ) + + step = MapReduceStep( + kind_="map_reduce", + map=subworkflow_step, + over="$ [1, 2, 3]", + parallelism=3, + ) + + execution_input = ExecutionInput( + developer_id=uuid.uuid4(), + agent=Agent( + id=uuid.uuid4(), name="test agent", created_at=utcnow(), updated_at=utcnow() + ), + agent_tools=[], + arguments={}, + task=TaskSpecDef( + name="task1", + tools=[], + workflows=[Workflow(name="main", steps=[step])], + ), + ) + + context = StepContext( + execution_input=execution_input, + current_input={"current_input": "value 1"}, + cursor=TransitionTarget( + workflow="main", + step=0, + ), + ) + with patch("agents_api.workflows.task_execution.helpers.workflow") as workflow: + workflow.execute_activity.return_value = await _resp() + with raises(ValueError): + await execute_map_reduce_step_parallel( + context=context, + map_defn=step.map, + execution_input=execution_input, + items=["1", "2", "3"], + current_input={}, + ) + + +@test("execute_map_reduce_step_parallel: parallelism must be greater than 1") +async def _(): + async def _resp(): + return "response" + + subworkflow_step = PromptStep(prompt=[PromptItem(content="hi there", role="user")]) + + step = MapReduceStep( + kind_="map_reduce", + map=subworkflow_step, + over="$ [1, 2, 3]", + parallelism=1, + ) + + execution_input = ExecutionInput( + developer_id=uuid.uuid4(), + agent=Agent( + id=uuid.uuid4(), name="test agent", created_at=utcnow(), updated_at=utcnow() + ), + agent_tools=[], + arguments={}, + task=TaskSpecDef( + name="task1", + tools=[], + workflows=[Workflow(name="main", steps=[step])], + ), + ) + + context = StepContext( + execution_input=execution_input, + current_input={"current_input": "value 1"}, + cursor=TransitionTarget( + workflow="main", + step=0, + ), + ) + with patch("agents_api.workflows.task_execution.helpers.workflow") as workflow: + workflow.execute_activity.return_value = await _resp() + with raises(AssertionError): + await execute_map_reduce_step_parallel( + context=context, + map_defn=step.map, + execution_input=execution_input, + items=["1", "2", "3"], + current_input={}, + parallelism=1, + )