Skip to content

Commit

Permalink
feat: Add more tests for parallel map step
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Sep 4, 2024
1 parent c96e0fb commit 36526f9
Showing 1 changed file with 41 additions and 37 deletions.
78 changes: 41 additions & 37 deletions agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from agents_api.models.task.create_task import create_task
from agents_api.routers.tasks.create_task_execution import start_execution
from agents_api.models.task.delete_task import delete_task

from .fixtures import cozo_client, test_agent, test_developer_id
from .utils import patch_testing_temporal
Expand Down Expand Up @@ -697,52 +698,55 @@ async def _(
assert [r["res"] for r in result] == ["a", "b", "c"]


@test("workflow: map reduce step parallel (basic)")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

map_step = {
"over": "'a b c d e f'.split()",
"map": {
"evaluate": {"res": "_"},
},
"parallelism": 2,
}
for p in range(1, 10):
@test(f"workflow: map reduce step parallel (parallelism={p})")
async def _(
client=cozo_client,
developer_id=test_developer_id,
agent=test_agent,
):
data = CreateExecutionRequest(input={"test": "input"})

task_def = {
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [map_step],
}
map_step = {
"over": "'a b c d e f g h i j k l m n o p q r s t u v w x y z'.split()",
"map": {
"evaluate": {"res": "_"},
},
"parallelism": p,
}

task = create_task(
developer_id=developer_id,
agent_id=agent.id,
data=CreateTaskRequest(**task_def),
client=client,
)
task_def = {
"name": "test task",
"description": "test task about",
"input_schema": {"type": "object", "additionalProperties": True},
"main": [map_step],
}

async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
task = create_task(
developer_id=developer_id,
task_id=task.id,
data=data,
agent_id=agent.id,
data=CreateTaskRequest(**task_def),
client=client,
)

assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input
async with patch_testing_temporal() as (_, mock_run_task_execution_workflow):
execution, handle = await start_execution(
developer_id=developer_id,
task_id=task.id,
data=data,
client=client,
)

mock_run_task_execution_workflow.assert_called_once()
assert handle is not None
assert execution.task_id == task.id
assert execution.input == data.input

result = await handle.result()
assert [r["res"] for r in result] == ["a", "b", "c", "d", "e", "f"]
mock_run_task_execution_workflow.assert_called_once()

result = await handle.result()
assert [r["res"] for r in result] == ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"]

delete_task(developer_id=developer_id, agent_id=agent.id, task_id=task.id, client=client)


@test("workflow: prompt step")
Expand Down

0 comments on commit 36526f9

Please sign in to comment.