Skip to content

Commit

Permalink
Merge pull request #490 from julep-ai/f/parallel-map-reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
creatorrr authored Sep 7, 2024
2 parents 6fe6e4e + 888887a commit 60a0129
Show file tree
Hide file tree
Showing 28 changed files with 535 additions and 185 deletions.
20 changes: 19 additions & 1 deletion agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
from typing import Any

from beartype import beartype
Expand All @@ -13,10 +14,27 @@
async def base_evaluate(
exprs: str | list[str] | dict[str, str],
values: dict[str, Any] = {},
extra_lambda_strs: dict[str, str] | None = None,
) -> Any | list[Any] | dict[str, Any]:
input_len = 1 if isinstance(exprs, str) else len(exprs)
assert input_len > 0, "exprs must be a non-empty string, list or dict"

extra_lambdas = {}
if extra_lambda_strs:
for k, v in extra_lambda_strs.items():
v = v.strip()

# Check that all extra lambdas are valid
assert v.startswith("lambda "), "All extra lambdas must start with 'lambda'"

try:
ast.parse(v)
except Exception as e:
raise ValueError(f"Invalid lambda: {v}") from e

# Eval the lambda and add it to the extra lambdas
extra_lambdas[k] = eval(v)

# Turn the nested dict values from pydantic to dicts where possible
values = {
k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()
Expand All @@ -25,7 +43,7 @@ async def base_evaluate(
# frozen_box doesn't work coz we need some mutability in the values
values = Box(values, frozen_box=False, conversion_box=True)

evaluator = get_evaluator(names=values)
evaluator = get_evaluator(names=values, extra_functions=extra_lambdas)

try:
match exprs:
Expand Down
21 changes: 18 additions & 3 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import json
from typing import Any
from functools import reduce
from itertools import accumulate
from random import random
from time import time
from typing import Any, Callable

import re2
import yaml
Expand All @@ -10,6 +14,7 @@
# TODO: We need to make sure that we dont expose any security issues
ALLOWED_FUNCTIONS = {
"abs": abs,
"accumulate": accumulate,
"all": all,
"any": any,
"bool": bool,
Expand All @@ -22,22 +27,32 @@
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"map": map,
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"random": random,
"range": range,
"reduce": reduce,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"set": set,
"str": str,
"sum": sum,
"time": time,
"tuple": tuple,
"zip": zip,
}


@beartype
def get_evaluator(names: dict[str, Any]) -> SimpleEval:
evaluator = EvalWithCompoundTypes(names=names, functions=ALLOWED_FUNCTIONS)
def get_evaluator(
names: dict[str, Any], extra_functions: dict[str, Callable] | None = None
) -> SimpleEval:
evaluator = EvalWithCompoundTypes(
names=names, functions=ALLOWED_FUNCTIONS | (extra_functions or {})
)

return evaluator


Expand Down
14 changes: 14 additions & 0 deletions agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,13 @@ class Main(BaseModel):
A special parameter named `results` is the accumulator and `_` is the current value.
"""
initial: Any = []
"""
The initial value of the reduce expression
"""
parallelism: Annotated[int | None, Field(None, ge=1, le=100)]
"""
Whether to run the reduce expression in parallel and how many items to run in each batch
"""


class MainModel(BaseModel):
Expand Down Expand Up @@ -381,6 +388,13 @@ class MainModel(BaseModel):
A special parameter named `results` is the accumulator and `_` is the current value.
"""
initial: Any = []
"""
The initial value of the reduce expression
"""
parallelism: Annotated[int | None, Field(None, ge=1, le=100)]
"""
Whether to run the reduce expression in parallel and how many items to run in each batch
"""


class ParallelStep(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
api_prefix: str = env.str("AGENTS_API_PREFIX", default="")


# Tasks
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)

# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,16 @@ def create_execution_transition(
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str], dict]:
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()

data.metadata = data.metadata or {}
data.execution_id = execution_id

# TODO: This is a hack to make sure the transition is valid
# (parallel transitions are whack, we should do something better)
is_parallel = data.current.workflow.startswith("PAR:")

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})

Expand Down Expand Up @@ -184,9 +188,9 @@ def create_execution_transition(
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query,
update_execution_query,
check_last_transition_query,
validate_status_query if not is_parallel else None,
update_execution_query if not is_parallel else None,
check_last_transition_query if not is_parallel else None,
insert_query,
]

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/get_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_session(
render_templates,
token_budget,
context_overflow,
@ "NOW"
@ "END"
}, updated_at = to_int(validity)
"""

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/list_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def list_sessions(
metadata,
token_budget,
context_overflow,
@ "NOW"
@ "END"
}},
users_p[users, id],
participants[agents, "agent", id],
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def patch_session(
input[{session_update_cols}],
ids[session_id, developer_id],
*sessions{{
{rest_fields}, metadata: md, @ "NOW"
{rest_fields}, metadata: md, @ "END"
}},
updated_at = 'ASSERT',
metadata = concat(md, $metadata),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def prepare_session_data(
render_templates,
token_budget,
context_overflow,
@ "NOW"
@ "END"
},
updated_at = to_int(validity),
record = {
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/session/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def update_session(
input[{session_update_cols}],
ids[session_id, developer_id],
*sessions{{
{rest_fields}, @ "NOW"
{rest_fields}, @ "END"
}},
updated_at = 'ASSERT'
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/task/get_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_task(
workflows,
created_at,
metadata,
@ 'NOW'
@ 'END'
},
updated_at = to_int(updated_at_ms) / 1000
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/task/list_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def list_tasks(
workflows,
created_at,
metadata,
@ 'NOW'
@ 'END'
}},
updated_at = to_int(updated_at_ms) / 1000
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def mark_session_updated_query(developer_id: UUID | str, session_id: UUID | str)
render_templates,
token_budget,
context_overflow,
@ 'NOW'
@ 'END'
}},
updated_at = [floor(now()), true]
Expand Down Expand Up @@ -173,7 +173,7 @@ def make_cozo_json_query(fields):


def cozo_query(
func: Callable[P, tuple[str | list[str], dict]] | None = None,
func: Callable[P, tuple[str | list[str | None], dict]] | None = None,
debug: bool | None = None,
):
def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
Expand Down
Loading

0 comments on commit 60a0129

Please sign in to comment.