Skip to content

Commit

Permalink
test: Test for each step, add small fixes
Browse files Browse the repository at this point in the history
whiterabbit1983 committed Aug 20, 2024
1 parent a63fb4a commit d2ab540
Showing 4 changed files with 66 additions and 18 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: F401, F403, F405

from .evaluate_step import evaluate_step
from .for_each_step import for_each_step
from .if_else_step import if_else_step
from .log_step import log_step
from .prompt_step import prompt_step
8 changes: 6 additions & 2 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,11 @@ async def for_each_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, ForeachStep)

return StepOutcome(output=simple_eval(context.current_step.foreach.in_, names=context.model_dump()))
return StepOutcome(
output=simple_eval(
context.current_step.foreach.in_, names=context.model_dump()
)
)
except BaseException as e:
logging.error(f"Error in for_each_step: {e}")
return StepOutcome(error=str(e))
@@ -27,6 +31,6 @@ async def for_each_step(context: StepContext) -> StepOutcome:
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = for_each_step

for_each_step = activity.defn(name="if_else_step")(
for_each_step = activity.defn(name="for_each_step")(
for_each_step if not testing else mock_if_else_step
)
28 changes: 12 additions & 16 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
@@ -14,8 +14,11 @@
CreateTransitionRequest,
ErrorWorkflowStep,
EvaluateStep,
ForeachDo,
ForeachStep,
IfElseWorkflowStep,
LogStep,
MapReduceStep,
# PromptStep,
ReturnStep,
SleepFor,
@@ -26,10 +29,6 @@
WaitForInputStep,
Workflow,
YieldStep,
ForeachStep,
ForeachDo,
SwitchStep,
MapReduceStep,
)
from ..common.protocol.tasks import (
ExecutionInput,
@@ -52,6 +51,7 @@
ReturnStep: task_steps.return_step,
YieldStep: task_steps.yield_step,
IfElseWorkflowStep: task_steps.if_else_step,
ForeachStep: task_steps.for_each_step,
}

# TODO: Avoid local activities for now (currently experimental)
@@ -198,14 +198,10 @@ async def transition(**kwargs) -> None:
args=if_else_args,
)

case ForeachStep(foreach=ForeachDo(do=do_step)), StepOutcome(
output=items
):
case ForeachStep(foreach=ForeachDo(do=do_step)), StepOutcome(output=items):
for i, item in enumerate(items):
# Create a faux workflow
foreach_wf_name = (
f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]"
)
foreach_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]"

foreach_task = execution_input.task.model_copy()
foreach_task.workflows = [
@@ -217,23 +213,23 @@ async def transition(**kwargs) -> None:
foreach_execution_input.task = foreach_task

# Set the next target to the chosen branch
foreach_next_target = TransitionTarget(workflow=foreach_wf_name, step=0)
foreach_next_target = TransitionTarget(
workflow=foreach_wf_name, step=0
)

foreach_args = [
foreach_execution_input,
foreach_next_target,
previous_inputs + [item],
previous_inputs + [{"item": item}],
]

# Execute the chosen branch and come back here
state.output = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=foreach_args,
)

case SwitchStep(switch=cases), StepOutcome(
output=int(case_num)
):

case SwitchStep(switch=cases), StepOutcome(output=int(case_num)):
if case_num > 0:
chosen_branch = cases[case_num]

47 changes: 47 additions & 0 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
@@ -551,3 +551,50 @@ async def _(

result = await handle.result()
assert result["hello"] == "world"


@test("workflow: for each step")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(
**{
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [
{
"foreach": {
"in": "'a b c'.split()",
"do": {"evaluate": {"hello": '"world"'}},
},
},
],
}
),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input

mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
assert result["hello"] == "world"

0 comments on commit d2ab540

Please sign in to comment.