Skip to content

Commit

Permalink
wip(agents-api): Add stub for switch step
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 19, 2024
1 parent 9fdafaf commit 69100c5
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 0 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .prompt_step import prompt_step
from .raise_complete_async import raise_complete_async
from .return_step import return_step
from .switch_step import switch_step
from .tool_call_step import tool_call_step
from .transition_step import transition_step
from .wait_for_input_step import wait_for_input_step
Expand Down
48 changes: 48 additions & 0 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import SwitchStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing


@beartype
async def switch_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, SwitchStep)

# Assume that none of the cases evaluate to truthy
output: int = -1

cases: list[str] = [c.case for c in context.current_step.switch]

for i, case in enumerate(cases):
result = simple_eval(case, names=context.model_dump())

if result:
output = i
break

result = StepOutcome(output=output)
return result

except BaseException as e:
logging.error(f"Error in switch_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported switch_step directly
# They do the same thing, so we dont need to mock the switch_step function
mock_switch_step = switch_step

switch_step = activity.defn(name="switch_step")(
switch_step if not testing else mock_switch_step
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def create_worker(client: Client) -> Any:
log_step,
prompt_step,
return_step,
switch_step,
tool_call_step,
transition_step,
wait_for_input_step,
Expand All @@ -45,6 +46,7 @@ def create_worker(client: Client) -> Any:
log_step,
prompt_step,
return_step,
switch_step,
tool_call_step,
transition_step,
wait_for_input_step,
Expand Down
10 changes: 10 additions & 0 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ReturnStep,
SleepFor,
SleepStep,
SwitchStep,
# ToolCallStep,
TransitionTarget,
WaitForInputStep,
Expand All @@ -40,8 +41,10 @@
# ToolCallStep: tool_call_step,
WaitForInputStep: task_steps.wait_for_input_step,
LogStep: task_steps.log_step,
SwitchStep: task_steps.switch_step,
}

# Use few local activities (currently experimental)
STEP_TO_LOCAL_ACTIVITY = {
# NOTE: local activities are directly called in the workflow executor
# They MUST NOT FAIL, otherwise they will crash the workflow
Expand Down Expand Up @@ -142,6 +145,13 @@ async def transition(**kwargs) -> None:
await transition(output=output, type="finish", next=None)
return output # <--- Byeeee!

case SwitchStep(switch=switch), StepOutcome(output=index) if index >= 0:
raise NotImplementedError("SwitchStep is not implemented")

case SwitchStep(), StepOutcome(output=index) if index < 0:
# If no case matched, then the output will be -1
raise NotImplementedError("SwitchStep is not implemented")

case IfElseWorkflowStep(then=then_branch, else_=else_branch), StepOutcome(
output=condition
):
Expand Down
57 changes: 57 additions & 0 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,60 @@ async def _(

result = await handle.result()
assert result["hello"] == "world"


@test("workflow: switch step")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [
{
"switch": [
{
"case": "False",
"then": {"evaluate": {"hello": '"bubbles"'}},
},
{
"case": "True",
"then": {"evaluate": {"hello": '"world"'}},
},
{
"case": "True",
"then": {"evaluate": {"hello": '"bye"'}},
},
]
},
],
}
),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input

mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
assert result["hello"] == "world"

0 comments on commit 69100c5

Please sign in to comment.