fix(agents-api): Fix map reduce step and activity
Signed-off-by: Diwank Tomer <[email protected]>
Diwank Tomer committed Aug 23, 2024
1 parent c3d9ce7 commit 2a12e3c
Showing 72 changed files with 610 additions and 1,558 deletions.
1 change: 1 addition & 0 deletions agents-api/agents_api/activities/task_steps/__init__.py
@@ -1,5 +1,6 @@
# ruff: noqa: F401, F403, F405

from .base_evaluate import base_evaluate
from .evaluate_step import evaluate_step
from .for_each_step import for_each_step
from .if_else_step import if_else_step
44 changes: 44 additions & 0 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
@@ -0,0 +1,44 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...env import testing
from ..utils import get_evaluator


@beartype
async def base_evaluate(
exprs: str | list[str] | dict[str, str],
values: dict[str, Any] = {},
) -> Any | list[Any] | dict[str, Any]:
input_len = 1 if isinstance(exprs, str) else len(exprs)
assert input_len > 0, "exprs must be a non-empty string, list or dict"

evaluator = get_evaluator(names=values)

try:
match exprs:
case str():
return evaluator.eval(exprs)

case list():
return [evaluator.eval(expr) for expr in exprs]

case dict():
return {k: evaluator.eval(v) for k, v in exprs.items()}

except BaseException as e:
if activity.in_activity():
activity.logger.error(f"Error in base_evaluate: {e}")

raise


# Note: This is here just for clarity. We could have just imported base_evaluate directly
# They do the same thing, so we don't need to mock the base_evaluate function
mock_base_evaluate = base_evaluate

base_evaluate = activity.defn(name="base_evaluate")(
base_evaluate if not testing else mock_base_evaluate
)
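
For reference, here is a minimal synchronous sketch of the same dispatch outside Temporal (the expressions and values below are made up; the real activity additionally logs and re-raises on failure when running inside an activity context):

from typing import Any

from simpleeval import EvalWithCompoundTypes


def sketch_base_evaluate(
    exprs: str | list[str] | dict[str, str],
    values: dict[str, Any],
) -> Any:
    # One evaluator, three input shapes -- mirrors the match statement above.
    evaluator = EvalWithCompoundTypes(names=values)
    match exprs:
        case str():
            return evaluator.eval(exprs)
        case list():
            return [evaluator.eval(expr) for expr in exprs]
        case dict():
            return {k: evaluator.eval(v) for k, v in exprs.items()}


print(sketch_base_evaluate("a + b", {"a": 1, "b": 2}))             # 3
print(sketch_base_evaluate(["a * 2", "b * 2"], {"a": 1, "b": 2}))  # [2, 4]
print(sketch_base_evaluate({"sum": "a + b"}, {"a": 1, "b": 2}))    # {'sum': 3}
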
29 changes: 18 additions & 11 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
@@ -1,28 +1,35 @@
from typing import Any

from beartype import beartype
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import EvaluateStep
from ...activities.utils import simple_eval_dict
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing


@beartype
async def evaluate_step(context: StepContext) -> StepOutcome:
# NOTE: This activity is only for returning immediately, so we just evaluate the expression
# Hence, it's a local activity and SHOULD NOT fail
async def evaluate_step(
context: StepContext,
additional_values: dict[str, Any] = {},
override_expr: dict[str, str] | None = None,
) -> StepOutcome:
try:
assert isinstance(context.current_step, EvaluateStep)

exprs = context.current_step.evaluate
output = simple_eval_dict(exprs, values=context.model_dump())

expr = (
override_expr
if override_expr is not None
else context.current_step.evaluate
)

values = context.model_dump() | additional_values
output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)

return result

except BaseException as e:
activity.logger.error(f"Error in evaluate_step: {e}")
return StepOutcome(error=str(e))
return StepOutcome(error=str(e) or repr(e))


# Note: This is here just for clarity. We could have just imported evaluate_step directly
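
A rough illustration of the two new parameters (all names and values here are hypothetical): additional_values is merged over the dumped context, so it wins on key clashes, and override_expr replaces the step's own evaluate mapping when provided.

from simpleeval import EvalWithCompoundTypes

context_values = {"inputs": {"name": "Ada"}, "greeting": "hi"}   # stand-in for context.model_dump()
additional_values = {"greeting": "hello"}                        # overrides the context value
override_expr = {"message": "greeting + ', ' + inputs['name']"}  # used instead of step.evaluate

values = context_values | additional_values
evaluator = EvalWithCompoundTypes(names=values)
print({k: evaluator.eval(v) for k, v in override_expr.items()})  # {'message': 'hello, Ada'}
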
10 changes: 5 additions & 5 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
@@ -1,7 +1,6 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import ForeachStep
@@ -10,18 +9,19 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
async def for_each_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, ForeachStep)

return StepOutcome(
output=simple_eval(
context.current_step.foreach.in_, names=context.model_dump()
)
output = await base_evaluate(
context.current_step.foreach.in_, context.model_dump()
)
return StepOutcome(output=output)

except BaseException as e:
logging.error(f"Error in for_each_step: {e}")
return StepOutcome(error=str(e))
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
@@ -1,5 +1,4 @@
from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import IfElseWorkflowStep
@@ -8,6 +7,7 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
@@ -18,7 +18,7 @@ async def if_else_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, IfElseWorkflowStep)

expr: str = context.current_step.if_
output = simple_eval(expr, names=context.model_dump())
output = await base_evaluate(expr, context.model_dump())
output: bool = bool(output)

result = StepOutcome(output=output)
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/log_step.py
@@ -1,5 +1,4 @@
from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import LogStep
@@ -8,6 +7,7 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
@@ -18,7 +18,7 @@ async def log_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, LogStep)

expr: str = context.current_step.log
output = simple_eval(expr, names=context.model_dump())
output = await base_evaluate(expr, context.model_dump())

result = StepOutcome(output=output)
return result
11 changes: 5 additions & 6 deletions agents-api/agents_api/activities/task_steps/map_reduce_step.py
@@ -1,7 +1,6 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import MapReduceStep
@@ -10,18 +9,18 @@
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


@beartype
async def map_reduce_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, MapReduceStep)

return StepOutcome(
output=simple_eval(
context.current_step.map.over, names=context.model_dump()
)
)
output = await base_evaluate(context.current_step.over, context.model_dump())

return StepOutcome(output=output)

except BaseException as e:
logging.error(f"Error in map_reduce_step: {e}")
return StepOutcome(error=str(e))
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/return_step.py
@@ -1,12 +1,12 @@
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import ReturnStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from ...env import testing
from .base_evaluate import base_evaluate


async def return_step(context: StepContext) -> StepOutcome:
@@ -16,7 +16,7 @@ async def return_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ReturnStep)

exprs: dict[str, str] = context.current_step.return_
output = simple_eval_dict(exprs, values=context.model_dump())
output = await base_evaluate(exprs, context.model_dump())

result = StepOutcome(output=output)
return result
7 changes: 4 additions & 3 deletions agents-api/agents_api/activities/task_steps/switch_step.py
@@ -1,5 +1,4 @@
from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import SwitchStep
@@ -8,6 +7,7 @@
StepOutcome,
)
from ...env import testing
from ..utils import get_evaluator


@beartype
@@ -17,11 +17,12 @@ async def switch_step(context: StepContext) -> StepOutcome:

# Assume that none of the cases evaluate to truthy
output: int = -1

cases: list[str] = [c.case for c in context.current_step.switch]

evaluator = get_evaluator(names=context.model_dump())

for i, case in enumerate(cases):
result = simple_eval(case, names=context.model_dump())
result = evaluator.eval(case)

if result:
output = i
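
A standalone sketch of the shared-evaluator switch loop (the case expressions and score value are invented and kept mutually exclusive so exactly one case is truthy; how the full step handles multiple truthy cases is not visible in this hunk):

from simpleeval import EvalWithCompoundTypes

cases = ["score > 90", "50 < score <= 90", "score <= 50"]  # hypothetical switch cases
evaluator = EvalWithCompoundTypes(names={"score": 70})

output = -1  # stays -1 when no case is truthy
for i, case in enumerate(cases):
    if evaluator.eval(case):
        output = i

print(output)  # 1
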
11 changes: 0 additions & 11 deletions agents-api/agents_api/activities/task_steps/utils.py

This file was deleted.

17 changes: 11 additions & 6 deletions agents-api/agents_api/activities/task_steps/wait_for_input_step.py
@@ -1,19 +1,24 @@
from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
from ...autogen.openapi_model import WaitForInputStep
from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing
from .base_evaluate import base_evaluate


async def wait_for_input_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, WaitForInputStep)
try:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input
output = simple_eval_dict(exprs, values=context.model_dump())
exprs = context.current_step.wait_for_input
output = await base_evaluate(exprs, context.model_dump())

result = StepOutcome(output=output)
return result
result = StepOutcome(output=output)
return result

except BaseException as e:
activity.logger.error(f"Error in wait_for_input_step: {e}")
return StepOutcome(error=str(e))


# Note: This is here just for clarity. We could have just imported wait_for_input_step directly
6 changes: 3 additions & 3 deletions agents-api/agents_api/activities/task_steps/yield_step.py
@@ -7,7 +7,7 @@

from ...common.protocol.tasks import StepContext, StepOutcome
from ...env import testing
from .utils import simple_eval_dict
from .base_evaluate import base_evaluate


@beartype
@@ -19,14 +19,14 @@ async def yield_step(context: StepContext) -> StepOutcome:

all_workflows = context.execution_input.task.workflows
workflow = context.current_step.workflow
exprs = context.current_step.arguments

assert workflow in [
wf.name for wf in all_workflows
], f"Workflow {workflow} not found in task"

# Evaluate the expressions in the arguments
exprs = context.current_step.arguments
arguments = simple_eval_dict(exprs, values=context.model_dump())
arguments = await base_evaluate(exprs, context.model_dump())

# Transition to the first step of that workflow
transition_target = TransitionTarget(
17 changes: 17 additions & 0 deletions agents-api/agents_api/activities/utils.py
@@ -0,0 +1,17 @@
from typing import Any

from beartype import beartype
from simpleeval import EvalWithCompoundTypes, SimpleEval


@beartype
def get_evaluator(names: dict[str, Any]) -> SimpleEval:
evaluator = EvalWithCompoundTypes(names=names)
return evaluator


@beartype
def simple_eval_dict(exprs: dict[str, str], values: dict[str, Any]) -> dict[str, Any]:
evaluator = get_evaluator(names=values)

return {k: evaluator.eval(v) for k, v in exprs.items()}
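
A quick usage example for the two new helpers, with made-up names and values (this assumes the agents_api package is importable in the current environment):

from agents_api.activities.utils import get_evaluator, simple_eval_dict

evaluator = get_evaluator(names={"a": 1, "b": 2})
print(evaluator.eval("a + b"))  # 3

print(simple_eval_dict({"sum": "a + b", "prod": "a * b"}, {"a": 2, "b": 3}))
# {'sum': 5, 'prod': 6}
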
10 changes: 0 additions & 10 deletions agents-api/agents_api/autogen/Common.py
@@ -9,16 +9,6 @@
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, RootModel


class JinjaTemplate(RootModel[str]):
model_config = ConfigDict(
populate_by_name=True,
)
root: str
"""
A valid jinja template.
"""


class Limit(RootModel[int]):
model_config = ConfigDict(
populate_by_name=True,
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/Executions.py
@@ -83,7 +83,7 @@ class Transition(BaseModel):
Field(json_schema_extra={"readOnly": True}),
]
execution_id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]
output: Annotated[dict[str, Any], Field(json_schema_extra={"readOnly": True})]
output: Annotated[Any, Field(json_schema_extra={"readOnly": True})]
current: Annotated[TransitionTarget, Field(json_schema_extra={"readOnly": True})]
next: Annotated[
TransitionTarget | None, Field(json_schema_extra={"readOnly": True})