Skip to content

Commit

Permalink
fix: Fix nasty parallelism race condition in cozodb
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 4, 2024
1 parent 36526f9 commit 45106f2
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 54 deletions.
3 changes: 2 additions & 1 deletion agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
from typing import Any

from beartype import beartype
Expand All @@ -22,7 +23,7 @@ async def base_evaluate(
if extra_lambda_strs:
for k, v in extra_lambda_strs.items():
assert v.startswith("lambda "), "All extra lambdas must start with 'lambda'"
extra_lambdas[k] = eval(v)
extra_lambdas[k] = ast.literal_eval(v)

# Turn the nested dict values from pydantic to dicts where possible
values = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,16 @@ def create_execution_transition(
# Only required for updating the execution status as well
update_execution_status: bool = False,
task_id: UUID | None = None,
) -> tuple[list[str], dict]:
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()

data.metadata = data.metadata or {}
data.execution_id = execution_id

# TODO: This is a hack to make sure the transition is valid
# (parallel transitions are whack, we should do something better)
is_parallel = data.current.workflow.startswith("PAR:")

# Prepare the transition data
transition_data = data.model_dump(exclude_unset=True, exclude={"id"})

Expand Down Expand Up @@ -184,9 +188,9 @@ def create_execution_transition(
execution_id=execution_id,
parents=[("agents", "agent_id"), ("tasks", "task_id")],
),
validate_status_query,
update_execution_query,
check_last_transition_query,
validate_status_query if not is_parallel else None,
update_execution_query if not is_parallel else None,
check_last_transition_query if not is_parallel else None,
insert_query,
]

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def make_cozo_json_query(fields):


def cozo_query(
func: Callable[P, tuple[str | list[str], dict]] | None = None,
func: Callable[P, tuple[str | list[str | None], dict]] | None = None,
debug: bool | None = None,
):
def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
Expand Down
18 changes: 11 additions & 7 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ async def execute_map_reduce_step_parallel(
batch_pending = []

for j, item in enumerate(batch):
workflow_name = f"`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]"
# Parallel batch workflow name
# Note: Added PAR: prefix to easily identify parallel batches in logs
workflow_name = f"PAR:`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]"
map_reduce_task = execution_input.task.model_copy()
map_reduce_task.workflows = [Workflow(name=workflow_name, steps=[map_defn])]

Expand All @@ -219,11 +221,13 @@ async def execute_map_reduce_step_parallel(
map_reduce_next_target = TransitionTarget(workflow=workflow_name, step=0)

batch_pending.append(
continue_as_child(
map_reduce_execution_input,
map_reduce_next_target,
previous_inputs + [item],
user_state=user_state,
asyncio.create_task(
continue_as_child(
map_reduce_execution_input,
map_reduce_next_target,
previous_inputs + [item],
user_state=user_state,
)
)
)

Expand All @@ -239,7 +243,7 @@ async def execute_map_reduce_step_parallel(
{"results": results, "_": batch_results},
extra_lambda_strs,
],
schedule_to_close_timeout=timedelta(seconds=2),
schedule_to_close_timeout=timedelta(seconds=5),
)

except BaseException as e:
Expand Down
105 changes: 64 additions & 41 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)
from agents_api.models.task.create_task import create_task
from agents_api.routers.tasks.create_task_execution import start_execution
from agents_api.models.task.delete_task import delete_task

from .fixtures import cozo_client, test_agent, test_developer_id
from .utils import patch_testing_temporal
Expand Down Expand Up @@ -698,55 +697,79 @@ async def _(
assert [r["res"] for r in result] == ["a", "b", "c"]


for p in range(1, 10):
@test(f"workflow: map reduce step parallel (parallelism={p})")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})
@test("workflow: map reduce step parallel (parallelism=10)")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

map_step = {
"over": "'a b c d e f g h i j k l m n o p q r s t u v w x y z'.split()",
"map": {
"evaluate": {"res": "_"},
},
"parallelism": p,
}
map_step = {
"over": "'a b c d e f g h i j k l m n o p q r s t u v w x y z'.split()",
"map": {
"evaluate": {"res": "_ + '!'"},
},
"parallelism": 10,
}

task_def = {
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [map_step],
}
task_def = {
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [map_step],
}

task = create_task(
task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(**task_def),
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(**task_def),
task_id=task.id,
data=data,
client=client,
)

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input

mock_run_task_execution_workflow.assert_called_once()
assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input

result = await handle.result()
assert [r["res"] for r in result] == ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"]
mock_run_task_execution_workflow.assert_called_once()

delete_task(developer_id=developer_id, agent_id=agent.id, task_id=task.id, client=client)
result = await handle.result()
assert [r["res"] for r in result] == [
"a!",
"b!",
"c!",
"d!",
"e!",
"f!",
"g!",
"h!",
"i!",
"j!",
"k!",
"l!",
"m!",
"n!",
"o!",
"p!",
"q!",
"r!",
"s!",
"t!",
"u!",
"v!",
"w!",
"x!",
"y!",
"z!",
]


@test("workflow: prompt step")
Expand Down

0 comments on commit 45106f2

Please sign in to comment.