diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index e1792f051..2f63fb81e 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -1,6 +1,7 @@ # ruff: noqa: F401, F403, F405 from .evaluate_step import evaluate_step +from .for_each_step import for_each_step from .if_else_step import if_else_step from .log_step import log_step from .prompt_step import prompt_step diff --git a/agents-api/agents_api/activities/task_steps/for_each_step.py b/agents-api/agents_api/activities/task_steps/for_each_step.py index e297d11dc..f2f24430f 100644 --- a/agents-api/agents_api/activities/task_steps/for_each_step.py +++ b/agents-api/agents_api/activities/task_steps/for_each_step.py @@ -17,7 +17,11 @@ async def for_each_step(context: StepContext) -> StepOutcome: try: assert isinstance(context.current_step, ForeachStep) - return StepOutcome(output=simple_eval(context.current_step.foreach.in_, names=context.model_dump())) + return StepOutcome( + output=simple_eval( + context.current_step.foreach.in_, names=context.model_dump() + ) + ) except BaseException as e: logging.error(f"Error in for_each_step: {e}") return StepOutcome(error=str(e)) @@ -27,6 +31,6 @@ async def for_each_step(context: StepContext) -> StepOutcome: # They do the same thing, so we dont need to mock the if_else_step function mock_if_else_step = for_each_step -for_each_step = activity.defn(name="if_else_step")( +for_each_step = activity.defn(name="for_each_step")( for_each_step if not testing else mock_if_else_step ) diff --git a/agents-api/agents_api/workflows/task_execution.py b/agents-api/agents_api/workflows/task_execution.py index f3452b8e9..c0ee89688 100644 --- a/agents-api/agents_api/workflows/task_execution.py +++ b/agents-api/agents_api/workflows/task_execution.py @@ -14,8 +14,11 @@ CreateTransitionRequest, ErrorWorkflowStep, EvaluateStep, + ForeachDo, + ForeachStep, IfElseWorkflowStep, LogStep, + MapReduceStep, # PromptStep, ReturnStep, SleepFor, @@ -26,10 +29,6 @@ WaitForInputStep, Workflow, YieldStep, - ForeachStep, - ForeachDo, - SwitchStep, - MapReduceStep, ) from ..common.protocol.tasks import ( ExecutionInput, @@ -52,6 +51,7 @@ ReturnStep: task_steps.return_step, YieldStep: task_steps.yield_step, IfElseWorkflowStep: task_steps.if_else_step, + ForeachStep: task_steps.for_each_step, } # TODO: Avoid local activities for now (currently experimental) @@ -198,14 +198,10 @@ async def transition(**kwargs) -> None: args=if_else_args, ) - case ForeachStep(foreach=ForeachDo(do=do_step)), StepOutcome( - output=items - ): + case ForeachStep(foreach=ForeachDo(do=do_step)), StepOutcome(output=items): for i, item in enumerate(items): # Create a faux workflow - foreach_wf_name = ( - f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]" - ) + foreach_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]" foreach_task = execution_input.task.model_copy() foreach_task.workflows = [ @@ -217,12 +213,14 @@ async def transition(**kwargs) -> None: foreach_execution_input.task = foreach_task # Set the next target to the chosen branch - foreach_next_target = TransitionTarget(workflow=foreach_wf_name, step=0) + foreach_next_target = TransitionTarget( + workflow=foreach_wf_name, step=0 + ) foreach_args = [ foreach_execution_input, foreach_next_target, - previous_inputs + [item], + previous_inputs + [{"item": item}], ] # Execute the chosen branch and come back here @@ -230,10 +228,8 @@ async def transition(**kwargs) -> None: TaskExecutionWorkflow.run, args=foreach_args, ) - - case SwitchStep(switch=cases), StepOutcome( - output=int(case_num) - ): + + case SwitchStep(switch=cases), StepOutcome(output=int(case_num)): if case_num > 0: chosen_branch = cases[case_num] diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index 2f6d434a3..19c4b68e7 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -551,3 +551,50 @@ async def _( result = await handle.result() assert result["hello"] == "world" + + +@test("workflow: for each step") +async def _( + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + data = CreateExecutionRequest(input={"test": "input"}) + + task = create_task( + developer_id=developer_id, + agent_id=agent.id, + data=CreateTaskRequest( + **{ + "name": "test task", + "description": "test task about", + "input_schema": {"type": "object", "additionalProperties": True}, + "main": [ + { + "foreach": { + "in": "'a b c'.split()", + "do": {"evaluate": {"hello": '"world"'}}, + }, + }, + ], + } + ), + client=client, + ) + + async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): + execution, handle = await start_execution( + developer_id=developer_id, + task_id=task.id, + data=data, + client=client, + ) + + assert handle is not None + assert execution.task_id == task.id + assert execution.input == data.input + + mock_run_task_execution_workflow.assert_called_once() + + result = await handle.result() + assert result["hello"] == "world"