From 45106f2102a03bceedd0b9c4ad2214df8a32245f Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Wed, 4 Sep 2024 19:23:27 -0400 Subject: [PATCH] fix: Fix nasty parallelism race condition on cozodb Signed-off-by: Diwank Singh Tomer --- .../activities/task_steps/base_evaluate.py | 3 +- .../execution/create_execution_transition.py | 12 +- agents-api/agents_api/models/utils.py | 2 +- .../workflows/task_execution/helpers.py | 18 +-- agents-api/tests/test_execution_workflow.py | 105 +++++++++++------- 5 files changed, 86 insertions(+), 54 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/base_evaluate.py b/agents-api/agents_api/activities/task_steps/base_evaluate.py index 24d2459ec..fb9412cd9 100644 --- a/agents-api/agents_api/activities/task_steps/base_evaluate.py +++ b/agents-api/agents_api/activities/task_steps/base_evaluate.py @@ -1,3 +1,4 @@ +import ast from typing import Any from beartype import beartype @@ -22,7 +23,7 @@ async def base_evaluate( if extra_lambda_strs: for k, v in extra_lambda_strs.items(): assert v.startswith("lambda "), "All extra lambdas must start with 'lambda'" - extra_lambdas[k] = eval(v) + extra_lambdas[k] = ast.literal_eval(v) # Turn the nested dict values from pydantic to dicts where possible values = { diff --git a/agents-api/agents_api/models/execution/create_execution_transition.py b/agents-api/agents_api/models/execution/create_execution_transition.py index 89e924bf6..79497120b 100644 --- a/agents-api/agents_api/models/execution/create_execution_transition.py +++ b/agents-api/agents_api/models/execution/create_execution_transition.py @@ -82,12 +82,16 @@ def create_execution_transition( # Only required for updating the execution status as well update_execution_status: bool = False, task_id: UUID | None = None, -) -> tuple[list[str], dict]: +) -> tuple[list[str | None], dict]: transition_id = transition_id or uuid4() data.metadata = data.metadata or {} data.execution_id = execution_id + # TODO: This is a hack to make sure the transition is valid + # (parallel transitions are whack, we should do something better) + is_parallel = data.current.workflow.startswith("PAR:") + # Prepare the transition data transition_data = data.model_dump(exclude_unset=True, exclude={"id"}) @@ -184,9 +188,9 @@ def create_execution_transition( execution_id=execution_id, parents=[("agents", "agent_id"), ("tasks", "task_id")], ), - validate_status_query, - update_execution_query, - check_last_transition_query, + validate_status_query if not is_parallel else None, + update_execution_query if not is_parallel else None, + check_last_transition_query if not is_parallel else None, insert_query, ] diff --git a/agents-api/agents_api/models/utils.py b/agents-api/agents_api/models/utils.py index 4613fe7c7..98bf2a590 100644 --- a/agents-api/agents_api/models/utils.py +++ b/agents-api/agents_api/models/utils.py @@ -173,7 +173,7 @@ def make_cozo_json_query(fields): def cozo_query( - func: Callable[P, tuple[str | list[str], dict]] | None = None, + func: Callable[P, tuple[str | list[str | None], dict]] | None = None, debug: bool | None = None, ): def cozo_query_dec(func: Callable[P, tuple[str | list[Any], dict]]): diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index ae16acca1..9bb383299 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -210,7 +210,9 @@ async def execute_map_reduce_step_parallel( batch_pending = [] for j, item in enumerate(batch): - workflow_name = f"`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]" + # Parallel batch workflow name + # Note: Added PAR: prefix to easily identify parallel batches in logs + workflow_name = f"PAR:`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]" map_reduce_task = execution_input.task.model_copy() map_reduce_task.workflows = [Workflow(name=workflow_name, steps=[map_defn])] @@ -219,11 +221,13 @@ async def execute_map_reduce_step_parallel( map_reduce_next_target = TransitionTarget(workflow=workflow_name, step=0) batch_pending.append( - continue_as_child( - map_reduce_execution_input, - map_reduce_next_target, - previous_inputs + [item], - user_state=user_state, + asyncio.create_task( + continue_as_child( + map_reduce_execution_input, + map_reduce_next_target, + previous_inputs + [item], + user_state=user_state, + ) ) ) @@ -239,7 +243,7 @@ async def execute_map_reduce_step_parallel( {"results": results, "_": batch_results}, extra_lambda_strs, ], - schedule_to_close_timeout=timedelta(seconds=2), + schedule_to_close_timeout=timedelta(seconds=5), ) except BaseException as e: diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index d29001071..81daa8933 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -14,7 +14,6 @@ ) from agents_api.models.task.create_task import create_task from agents_api.routers.tasks.create_task_execution import start_execution -from agents_api.models.task.delete_task import delete_task from .fixtures import cozo_client, test_agent, test_developer_id from .utils import patch_testing_temporal @@ -698,55 +697,79 @@ async def _( assert [r["res"] for r in result] == ["a", "b", "c"] -for p in range(1, 10): - @test(f"workflow: map reduce step parallel (parallelism={p})") - async def _( - client=cozo_client, - developer_id=test_developer_id, - agent=test_agent, - ): - data = CreateExecutionRequest(input={"test": "input"}) +@test("workflow: map reduce step parallel (parallelism=10)") +async def _( + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + data = CreateExecutionRequest(input={"test": "input"}) - map_step = { - "over": "'a b c d e f g h i j k l m n o p q r s t u v w x y z'.split()", - "map": { - "evaluate": {"res": "_"}, - }, - "parallelism": p, - } + map_step = { + "over": "'a b c d e f g h i j k l m n o p q r s t u v w x y z'.split()", + "map": { + "evaluate": {"res": "_ + '!'"}, + }, + "parallelism": 10, + } - task_def = { - "name": "test task", - "description": "test task about", - "input_schema": {"type": "object", "additionalProperties": True}, - "main": [map_step], - } + task_def = { + "name": "test task", + "description": "test task about", + "input_schema": {"type": "object", "additionalProperties": True}, + "main": [map_step], + } - task = create_task( + task = create_task( + developer_id=developer_id, + agent_id=agent.id, + data=CreateTaskRequest(**task_def), + client=client, + ) + + async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): + execution, handle = await start_execution( developer_id=developer_id, - agent_id=agent.id, - data=CreateTaskRequest(**task_def), + task_id=task.id, + data=data, client=client, ) - async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): - execution, handle = await start_execution( - developer_id=developer_id, - task_id=task.id, - data=data, - client=client, - ) - - assert handle is not None - assert execution.task_id == task.id - assert execution.input == data.input - - mock_run_task_execution_workflow.assert_called_once() + assert handle is not None + assert execution.task_id == task.id + assert execution.input == data.input - result = await handle.result() - assert [r["res"] for r in result] == ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"] + mock_run_task_execution_workflow.assert_called_once() - delete_task(developer_id=developer_id, agent_id=agent.id, task_id=task.id, client=client) + result = await handle.result() + assert [r["res"] for r in result] == [ + "a!", + "b!", + "c!", + "d!", + "e!", + "f!", + "g!", + "h!", + "i!", + "j!", + "k!", + "l!", + "m!", + "n!", + "o!", + "p!", + "q!", + "r!", + "s!", + "t!", + "u!", + "v!", + "w!", + "x!", + "y!", + "z!", + ] @test("workflow: prompt step")