Skip to content

Commit

Permalink
feat: Add "switch" and "for each" activities
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Aug 20, 2024
1 parent 8263aea commit a63fb4a
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 2 deletions.
32 changes: 32 additions & 0 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import ForeachStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing


@beartype
async def for_each_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, ForeachStep)

return StepOutcome(output=simple_eval(context.current_step.foreach.in_, names=context.model_dump()))
except BaseException as e:
logging.error(f"Error in for_each_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported if_else_step directly
# They do the same thing, so we dont need to mock the if_else_step function
mock_if_else_step = for_each_step

for_each_step = activity.defn(name="if_else_step")(
for_each_step if not testing else mock_if_else_step
)
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

@beartype
async def switch_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for logging, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
try:
assert isinstance(context.current_step, SwitchStep)

Expand Down
72 changes: 72 additions & 0 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
WaitForInputStep,
Workflow,
YieldStep,
ForeachStep,
ForeachDo,
SwitchStep,
MapReduceStep,
)
from ..common.protocol.tasks import (
ExecutionInput,
Expand Down Expand Up @@ -194,6 +198,74 @@ async def transition(**kwargs) -> None:
args=if_else_args,
)

case ForeachStep(foreach=ForeachDo(do=do_step)), StepOutcome(
output=items
):
for i, item in enumerate(items):
# Create a faux workflow
foreach_wf_name = (
f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]"
)

foreach_task = execution_input.task.model_copy()
foreach_task.workflows = [
Workflow(name=foreach_wf_name, steps=[do_step])
]

# Create a new execution input
foreach_execution_input = execution_input.model_copy()
foreach_execution_input.task = foreach_task

# Set the next target to the chosen branch
foreach_next_target = TransitionTarget(workflow=foreach_wf_name, step=0)

foreach_args = [
foreach_execution_input,
foreach_next_target,
previous_inputs + [item],
]

# Execute the chosen branch and come back here
state.output = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=foreach_args,
)

case SwitchStep(switch=cases), StepOutcome(
output=int(case_num)
):
if case_num > 0:
chosen_branch = cases[case_num]

# Create a faux workflow
case_wf_name = (
f"`{context.cursor.workflow}`[{context.cursor.step}].case"
)

case_task = execution_input.task.model_copy()
case_task.workflows = [
Workflow(name=case_wf_name, steps=[chosen_branch.then])
]

# Create a new execution input
case_execution_input = execution_input.model_copy()
case_execution_input.task = case_task

# Set the next target to the chosen branch
case_next_target = TransitionTarget(workflow=case_wf_name, step=0)

case_args = [
case_execution_input,
case_next_target,
previous_inputs,
]

# Execute the chosen branch and come back here
state.output = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=case_args,
)

case SleepStep(
sleep=SleepFor(
seconds=seconds,
Expand Down

0 comments on commit a63fb4a

Please sign in to comment.