Skip to content

Commit

Permalink
refactor: Refactor for each step
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Nov 27, 2024
1 parent f9f2308 commit 85a688c
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 87 deletions.
1 change: 0 additions & 1 deletion agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .base_evaluate import base_evaluate
from .cozo_query_step import cozo_query_step
from .evaluate_step import evaluate_step
from .for_each_step import for_each_step
from .get_value_step import get_value_step
from .if_else_step import if_else_step
from .log_step import log_step
Expand Down
13 changes: 10 additions & 3 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ def __init__(self, error, expression, values):


# Recursive evaluation helper function
def _recursive_evaluate(expr, evaluator: SimpleEval):
def _recursive_evaluate(expr, evaluator: SimpleEval, eval_prompt_prefix: str = "$_"):
if isinstance(expr, str):
try:
return evaluator.eval(expr)
result = expr
if expr.startswith(eval_prompt_prefix):
result = evaluator.eval(expr)

return result
except Exception as e:
if activity.in_activity():
evaluate_error = EvaluateError(e, expr, evaluator.names)
Expand Down Expand Up @@ -69,6 +73,7 @@ async def base_evaluate(
exprs: Any,
values: dict[str, Any] = {},
extra_lambda_strs: dict[str, str] | None = None,
eval_prompt_prefix: str = "$_",
) -> Any | list[Any] | dict[str, Any]:
input_len = 1 if isinstance(exprs, str) else len(exprs)
assert input_len > 0, "exprs must be a non-empty string, list or dict"
Expand Down Expand Up @@ -100,7 +105,9 @@ async def base_evaluate(
evaluator: SimpleEval = get_evaluator(names=values, extra_functions=extra_lambdas)

# Recursively evaluate the expression
result = _recursive_evaluate(exprs, evaluator)
result = _recursive_evaluate(
exprs, evaluator, eval_prompt_prefix=eval_prompt_prefix
)
return result


Expand Down
36 changes: 0 additions & 36 deletions agents-api/agents_api/activities/task_steps/for_each_step.py

This file was deleted.

14 changes: 5 additions & 9 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ...activities.excecute_api_call import execute_api_call
from ...activities.execute_integration import execute_integration
from ...activities.execute_system import execute_system
from ...activities.task_steps.base_evaluate import base_evaluate
from ...activities.sync_items_remote import load_inputs_remote, save_inputs_remote
from ...autogen.openapi_model import (
ApiCallDef,
Expand Down Expand Up @@ -52,7 +53,6 @@
from ...env import debug, temporal_schedule_to_close_timeout, testing
from .helpers import (
continue_as_child,
execute_foreach_step,
execute_if_else_branch,
execute_map_reduce_step,
execute_map_reduce_step_parallel,
Expand Down Expand Up @@ -96,7 +96,6 @@
ReturnStep: task_steps.return_step,
YieldStep: task_steps.yield_step,
IfElseWorkflowStep: task_steps.if_else_step,
ForeachStep: task_steps.for_each_step,
MapReduceStep: task_steps.map_reduce_step,
SetStep: task_steps.set_value_step,
# GetStep: task_steps.get_value_step,
Expand Down Expand Up @@ -271,13 +270,10 @@ async def run(
state = PartialTransition(output=result)

case ForeachStep(foreach=ForeachDo(do=do_step)), StepOutcome(output=items):
result = await execute_foreach_step(
context=context,
execution_input=execution_input,
do_step=do_step,
items=items,
previous_inputs=previous_inputs,
)
result = [
base_evaluate(do_step.foreach.in_, {f"item": item})
for item in items
]
state = PartialTransition(output=result)

case MapReduceStep(
Expand Down
39 changes: 1 addition & 38 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
from ...activities.task_steps.base_evaluate import base_evaluate
from ...autogen.openapi_model import (
EvaluateStep,
TransitionTarget,
Expand Down Expand Up @@ -124,44 +125,6 @@ async def execute_if_else_branch(
)


@auto_blob_store_workflow
async def execute_foreach_step(
*,
context: StepContext,
execution_input: ExecutionInput,
do_step: WorkflowStep,
items: list[Any],
previous_inputs: RemoteList | list[Any],
user_state: dict[str, Any] = {},
) -> Any:
workflow.logger.info(f"Foreach step: Iterating over {len(items)} items")
results = []

for i, item in enumerate(items):
foreach_wf_name = (
f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]"
)
foreach_task = execution_input.task.model_copy()
foreach_task.workflows = [
Workflow(name=foreach_wf_name, steps=[do_step]),
*foreach_task.workflows,
]

foreach_execution_input = execution_input.model_copy()
foreach_execution_input.task = foreach_task
foreach_next_target = TransitionTarget(workflow=foreach_wf_name, step=0)

result = await continue_as_child(
foreach_execution_input,
foreach_next_target,
previous_inputs + [item],
user_state=user_state,
)
results.append(result)

return results


@auto_blob_store_workflow
async def execute_map_reduce_step(
*,
Expand Down

0 comments on commit 85a688c

Please sign in to comment.