diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index 73cf802de..e1792f051 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -6,6 +6,7 @@ from .prompt_step import prompt_step from .raise_complete_async import raise_complete_async from .return_step import return_step +from .switch_step import switch_step from .tool_call_step import tool_call_step from .transition_step import transition_step from .wait_for_input_step import wait_for_input_step diff --git a/agents-api/agents_api/activities/task_steps/switch_step.py b/agents-api/agents_api/activities/task_steps/switch_step.py new file mode 100644 index 000000000..60e64033e --- /dev/null +++ b/agents-api/agents_api/activities/task_steps/switch_step.py @@ -0,0 +1,48 @@ +import logging + +from beartype import beartype +from simpleeval import simple_eval +from temporalio import activity + +from ...autogen.openapi_model import SwitchStep +from ...common.protocol.tasks import ( + StepContext, + StepOutcome, +) +from ...env import testing + + +@beartype +async def switch_step(context: StepContext) -> StepOutcome: + # NOTE: This activity is only for logging, so we just evaluate the expression + # Hence, it's a local activity and SHOULD NOT fail + try: + assert isinstance(context.current_step, SwitchStep) + + # Assume that none of the cases evaluate to truthy + output: int = -1 + + cases: list[str] = [c.case for c in context.current_step.switch] + + for i, case in enumerate(cases): + result = simple_eval(case, names=context.model_dump()) + + if result: + output = i + break + + result = StepOutcome(output=output) + return result + + except BaseException as e: + logging.error(f"Error in switch_step: {e}") + return StepOutcome(error=str(e)) + + +# Note: This is here just for clarity. We could have just imported switch_step directly +# They do the same thing, so we dont need to mock the switch_step function +mock_switch_step = switch_step + +switch_step = activity.defn(name="switch_step")( + switch_step if not testing else mock_switch_step +) diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py index 10b3753ad..604454c4c 100644 --- a/agents-api/agents_api/worker/worker.py +++ b/agents-api/agents_api/worker/worker.py @@ -22,6 +22,7 @@ def create_worker(client: Client) -> Any: log_step, prompt_step, return_step, + switch_step, tool_call_step, transition_step, wait_for_input_step, @@ -45,6 +46,7 @@ def create_worker(client: Client) -> Any: log_step, prompt_step, return_step, + switch_step, tool_call_step, transition_step, wait_for_input_step, diff --git a/agents-api/agents_api/workflows/task_execution.py b/agents-api/agents_api/workflows/task_execution.py index 546bbca62..4cc419e86 100644 --- a/agents-api/agents_api/workflows/task_execution.py +++ b/agents-api/agents_api/workflows/task_execution.py @@ -20,6 +20,7 @@ ReturnStep, SleepFor, SleepStep, + SwitchStep, # ToolCallStep, TransitionTarget, WaitForInputStep, @@ -40,8 +41,10 @@ # ToolCallStep: tool_call_step, WaitForInputStep: task_steps.wait_for_input_step, LogStep: task_steps.log_step, + SwitchStep: task_steps.switch_step, } +# Use few local activities (currently experimental) STEP_TO_LOCAL_ACTIVITY = { # NOTE: local activities are directly called in the workflow executor # They MUST NOT FAIL, otherwise they will crash the workflow @@ -142,6 +145,13 @@ async def transition(**kwargs) -> None: await transition(output=output, type="finish", next=None) return output # <--- Byeeee! + case SwitchStep(switch=switch), StepOutcome(output=index) if index >= 0: + raise NotImplementedError("SwitchStep is not implemented") + + case SwitchStep(), StepOutcome(output=index) if index < 0: + # If no case matched, then the output will be -1 + raise NotImplementedError("SwitchStep is not implemented") + case IfElseWorkflowStep(then=then_branch, else_=else_branch), StepOutcome( output=condition ): diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index 999643168..2f6d434a3 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -494,3 +494,60 @@ async def _( result = await handle.result() assert result["hello"] == "world" + + +@test("workflow: switch step") +async def _( + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + data = CreateExecutionRequest(input={"test": "input"}) + + task = create_task( + developer_id=developer_id, + agent_id=agent.id, + data=CreateTaskRequest( + **{ + "name": "test task", + "description": "test task about", + "input_schema": {"type": "object", "additionalProperties": True}, + "main": [ + { + "switch": [ + { + "case": "False", + "then": {"evaluate": {"hello": '"bubbles"'}}, + }, + { + "case": "True", + "then": {"evaluate": {"hello": '"world"'}}, + }, + { + "case": "True", + "then": {"evaluate": {"hello": '"bye"'}}, + }, + ] + }, + ], + } + ), + client=client, + ) + + async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): + execution, handle = await start_execution( + developer_id=developer_id, + task_id=task.id, + data=data, + client=client, + ) + + assert handle is not None + assert execution.task_id == task.id + assert execution.input == data.input + + mock_run_task_execution_workflow.assert_called_once() + + result = await handle.result() + assert result["hello"] == "world"