Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add parallelism option to map-reduce step #490

Merged
merged 9 commits into from
Sep 7, 2024
9 changes: 8 additions & 1 deletion agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
async def base_evaluate(
exprs: str | list[str] | dict[str, str],
values: dict[str, Any] = {},
extra_lambda_strs: dict[str, str] | None = None,
) -> Any | list[Any] | dict[str, Any]:
input_len = 1 if isinstance(exprs, str) else len(exprs)
assert input_len > 0, "exprs must be a non-empty string, list or dict"

extra_lambdas = {}
if extra_lambda_strs:
for k, v in extra_lambda_strs.items():
assert v.startswith("lambda "), "All extra lambdas must start with 'lambda'"
extra_lambdas[k] = eval(v)
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
creatorrr marked this conversation as resolved.
Show resolved Hide resolved

# Turn the nested dict values from pydantic to dicts where possible
values = {
k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()
Expand All @@ -25,7 +32,7 @@ async def base_evaluate(
# frozen_box doesn't work because we need some mutability in the values
values = Box(values, frozen_box=False, conversion_box=True)

evaluator = get_evaluator(names=values)
evaluator = get_evaluator(names=values, extra_functions=extra_lambdas)

try:
match exprs:
Expand Down
17 changes: 14 additions & 3 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
from typing import Any
from functools import reduce
from itertools import accumulate
from typing import Any, Callable

import re2
import yaml
Expand All @@ -10,6 +12,7 @@
# TODO: We need to make sure that we don't expose any security issues
ALLOWED_FUNCTIONS = {
"abs": abs,
"accumulate": accumulate,
"all": all,
"any": any,
"bool": bool,
Expand All @@ -22,9 +25,12 @@
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"map": map,
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"range": range,
"reduce": reduce,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"set": set,
Expand All @@ -36,8 +42,13 @@


@beartype
def get_evaluator(names: dict[str, Any]) -> SimpleEval:
evaluator = EvalWithCompoundTypes(names=names, functions=ALLOWED_FUNCTIONS)
def get_evaluator(
names: dict[str, Any], extra_functions: dict[str, Callable] | None = None
) -> SimpleEval:
evaluator = EvalWithCompoundTypes(
names=names, functions=ALLOWED_FUNCTIONS | (extra_functions or {})
)

return evaluator


Expand Down
14 changes: 14 additions & 0 deletions agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,13 @@ class Main(BaseModel):
A special parameter named `results` is the accumulator and `_` is the current value.
"""
initial: Any = []
"""
The initial value of the reduce expression
"""
parallelism: int | None = None
"""
Whether to run the reduce expression in parallel and how many items to run in each batch
"""


class MainModel(BaseModel):
Expand Down Expand Up @@ -381,6 +388,13 @@ class MainModel(BaseModel):
A special parameter named `results` is the accumulator and `_` is the current value.
"""
initial: Any = []
"""
The initial value of the reduce expression
"""
parallelism: int | None = None
"""
Whether to run the reduce expression in parallel and how many items to run in each batch
"""


class ParallelStep(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
api_prefix: str = env.str("AGENTS_API_PREFIX", default="")


# Tasks
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)

# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
Expand Down
Loading
Loading