feat(agents-api): Add parallelism option to map-reduce step #490

Merged
merged 9 commits on Sep 7, 2024
20 changes: 19 additions & 1 deletion agents-api/agents_api/activities/task_steps/base_evaluate.py
@@ -1,3 +1,4 @@
import ast
from typing import Any

from beartype import beartype
@@ -13,10 +14,27 @@
async def base_evaluate(
exprs: str | list[str] | dict[str, str],
values: dict[str, Any] = {},
extra_lambda_strs: dict[str, str] | None = None,
) -> Any | list[Any] | dict[str, Any]:
input_len = 1 if isinstance(exprs, str) else len(exprs)
assert input_len > 0, "exprs must be a non-empty string, list or dict"

extra_lambdas = {}
if extra_lambda_strs:
for k, v in extra_lambda_strs.items():
v = v.strip()

# Check that all extra lambdas are valid
assert v.startswith("lambda "), "All extra lambdas must start with 'lambda'"

try:
ast.parse(v)
except Exception as e:
raise ValueError(f"Invalid lambda: {v}") from e

# Eval the lambda and add it to the extra lambdas
extra_lambdas[k] = eval(v)

# Turn the nested dict values from pydantic to dicts where possible
values = {
k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()
@@ -25,7 +43,7 @@ async def base_evaluate(
# frozen_box doesn't work coz we need some mutability in the values
values = Box(values, frozen_box=False, conversion_box=True)

evaluator = get_evaluator(names=values)
evaluator = get_evaluator(names=values, extra_functions=extra_lambdas)

try:
match exprs:
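For context, a minimal usage sketch of the new `extra_lambda_strs` parameter (illustrative only, not part of the diff; the expression and values are made up). Each extra lambda is passed as source text, validated (it must start with "lambda " and must parse), eval'd, and then exposed to the evaluator as a named function:

# Inside an async context; hypothetical expression and values.
result = await base_evaluate(
    "double(x) + 1",
    values={"x": 20},
    extra_lambda_strs={"double": "lambda x: x * 2"},
)
assert result == 41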
21 changes: 18 additions & 3 deletions agents-api/agents_api/activities/utils.py
@@ -1,5 +1,9 @@
import json
from typing import Any
from functools import reduce
from itertools import accumulate
from random import random
from time import time
from typing import Any, Callable

import re2
import yaml
@@ -10,6 +14,7 @@
# TODO: We need to make sure that we dont expose any security issues
ALLOWED_FUNCTIONS = {
"abs": abs,
"accumulate": accumulate,
"all": all,
"any": any,
"bool": bool,
@@ -22,22 +27,32 @@
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"map": map,
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"random": random,
[Contributor review comment] Adding random to ALLOWED_FUNCTIONS can lead to non-deterministic behavior. Ensure this is intended and consider the implications for reproducibility.

"range": range,
"reduce": reduce,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"set": set,
"str": str,
"sum": sum,
"time": time,
"tuple": tuple,
"zip": zip,
}


@beartype
def get_evaluator(names: dict[str, Any]) -> SimpleEval:
evaluator = EvalWithCompoundTypes(names=names, functions=ALLOWED_FUNCTIONS)
def get_evaluator(
names: dict[str, Any], extra_functions: dict[str, Callable] | None = None
) -> SimpleEval:
evaluator = EvalWithCompoundTypes(
names=names, functions=ALLOWED_FUNCTIONS | (extra_functions or {})
)

return evaluator


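A rough sketch of how the merged function table is meant to be used (assumed behavior based on simpleeval's named-functions mechanism; not taken from the diff). ALLOWED_FUNCTIONS stays the base allow-list, and extra_functions is layered on top, so callers can expose task-specific helpers without widening the global set:

evaluator = get_evaluator(
    names={"items": [1, 2, 3]},
    extra_functions={"double": lambda x: x * 2},
)
evaluator.eval("sum(map(double, items))")  # -> 12, using the newly allowed map()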
14 changes: 14 additions & 0 deletions agents-api/agents_api/autogen/Tasks.py
@@ -347,6 +347,13 @@ class Main(BaseModel):
A special parameter named `results` is the accumulator and `_` is the current value.
"""
initial: Any = []
"""
The initial value of the reduce expression
"""
parallelism: Annotated[int | None, Field(None, ge=1, le=100)]
"""
Whether to run the reduce expression in parallel and how many items to run in each batch
"""


class MainModel(BaseModel):
@@ -381,6 +388,13 @@ class MainModel(BaseModel):
A special parameter named `results` is the accumulator and `_` is the current value.
"""
initial: Any = []
"""
The initial value of the reduce expression
"""
parallelism: Annotated[int | None, Field(None, ge=1, le=100)]
"""
Whether to run the reduce expression in parallel and how many items to run in each batch
"""


class ParallelStep(BaseModel):
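As a rough illustration of the new field's constraints (a standalone sketch assuming standard Pydantic v2 validation; the ExampleStep class is hypothetical, while the real field lives in the generated Tasks.py):

from typing import Annotated
from pydantic import BaseModel, Field, ValidationError

class ExampleStep(BaseModel):
    # Mirrors the generated field: optional, and when set it must be 1..100.
    parallelism: Annotated[int | None, Field(None, ge=1, le=100)]

ExampleStep()                   # defaults to None: the reduce runs sequentially
ExampleStep(parallelism=10)     # process the mapped items in batches of 10
try:
    ExampleStep(parallelism=0)  # rejected by the ge=1 bound
except ValidationError:
    pass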
4 changes: 4 additions & 0 deletions agents-api/agents_api/env.py
@@ -20,6 +20,10 @@
api_prefix: str = env.str("AGENTS_API_PREFIX", default="")


# Tasks
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)

# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
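Presumably this setting acts as a deployment-wide ceiling on whatever a step requests; a hypothetical clamp (the effective_parallelism helper is not part of the diff) might look like:

from agents_api.env import task_max_parallelism

def effective_parallelism(requested: int | None) -> int:
    # A step that doesn't request parallelism runs one item at a time;
    # anything larger is capped by AGENTS_API_TASK_MAX_PARALLELISM.
    return min(requested or 1, task_max_parallelism)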
@@ -82,12 +82,16 @@ def create_execution_transition(
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str], dict]:
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()

data.metadata = data.metadata or {}
data.execution_id = execution_id

# TODO: This is a hack to make sure the transition is valid
# (parallel transitions are whack, we should do something better)
is_parallel = data.current.workflow.startswith("PAR:")

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})

@@ -184,9 +188,9 @@
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query,
update_execution_query,
check_last_transition_query,
validate_status_query if not is_parallel else None,
update_execution_query if not is_parallel else None,
check_last_transition_query if not is_parallel else None,
insert_query,
]

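With this change the queries list can contain None entries for parallel ("PAR:") transitions, which matches the widened cozo_query signature further down; presumably the decorator drops those entries before composing the Cozo program. A minimal sketch of that assumption (the filter_queries helper is hypothetical):

def filter_queries(queries: list[str | None]) -> list[str]:
    # Drop the entries that were disabled (set to None) for parallel
    # transitions, keeping the remaining queries in their original order.
    return [q for q in queries if q is not None]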
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/get_session.py
@@ -94,7 +94,7 @@ def get_session(
render_templates,
token_budget,
context_overflow,
@ "NOW"
@ "END"
}, updated_at = to_int(validity)
"""

2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/list_sessions.py
@@ -101,7 +101,7 @@ def list_sessions(
metadata,
token_budget,
context_overflow,
@ "NOW"
@ "END"
}},
users_p[users, id],
participants[agents, "agent", id],
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/patch_session.py
@@ -95,7 +95,7 @@ def patch_session(
input[{session_update_cols}],
ids[session_id, developer_id],
*sessions{{
{rest_fields}, metadata: md, @ "NOW"
{rest_fields}, metadata: md, @ "END"
}},
updated_at = 'ASSERT',
metadata = concat(md, $metadata),
@@ -159,7 +159,7 @@ def prepare_session_data(
render_templates,
token_budget,
context_overflow,
@ "NOW"
@ "END"
},
updated_at = to_int(validity),
record = {
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/update_session.py
@@ -84,7 +84,7 @@ def update_session(
input[{session_update_cols}],
ids[session_id, developer_id],
*sessions{{
{rest_fields}, @ "NOW"
{rest_fields}, @ "END"
}},
updated_at = 'ASSERT'

2 changes: 1 addition & 1 deletion agents-api/agents_api/models/task/get_task.py
@@ -64,7 +64,7 @@ def get_task(
workflows,
created_at,
metadata,
@ 'NOW'
@ 'END'
},
updated_at = to_int(updated_at_ms) / 1000

2 changes: 1 addition & 1 deletion agents-api/agents_api/models/task/list_tasks.py
@@ -70,7 +70,7 @@ def list_tasks(
workflows,
created_at,
metadata,
@ 'NOW'
@ 'END'
}},
updated_at = to_int(updated_at_ms) / 1000

4 changes: 2 additions & 2 deletions agents-api/agents_api/models/utils.py
@@ -97,7 +97,7 @@ def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str)
render_templates,
token_budget,
context_overflow,
@ 'NOW'
@ 'END'
}},
updated_at = [floor(now()), true]

@@ -173,7 +173,7 @@ def make_cozo_json_query(fields):


def cozo_query(
func: Callable[P, tuple[str | list[str], dict]] | None = None,
func: Callable[P, tuple[str | list[str | None], dict]] | None = None,
debug: bool | None = None,
):
def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):