From 3a38e70016ba6e08bcf8874c3841a79221550406 Mon Sep 17 00:00:00 2001 From: Diwank Tomer Date: Wed, 14 Aug 2024 13:46:16 -0400 Subject: [PATCH] feat(agents-api): Add some workflow tests Signed-off-by: Diwank Tomer --- .../activities/task_steps/__init__.py | 18 +----- .../agents_api/common/protocol/tasks.py | 2 +- .../execution/create_execution_transition.py | 57 +++++++++++++++++-- agents-api/tests/test_execution_queries.py | 28 +++++++++ agents-api/tests/test_task_routes.py | 14 ++--- 5 files changed, 91 insertions(+), 28 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/__init__.py b/agents-api/agents_api/activities/task_steps/__init__.py index 13f1adcfe..e3d3099bb 100644 --- a/agents-api/agents_api/activities/task_steps/__init__.py +++ b/agents-api/agents_api/activities/task_steps/__init__.py @@ -11,7 +11,6 @@ InputChatMLMessage, PromptStep, ToolCallStep, - UpdateExecutionRequest, YieldStep, ) from ...clients import ( @@ -25,9 +24,6 @@ from ...models.execution.create_execution_transition import ( create_execution_transition as create_execution_transition_query, ) -from ...models.execution.update_execution import ( - update_execution as update_execution_query, -) @activity.defn @@ -134,8 +130,7 @@ async def transition_step( "cancelled", ] = "awaiting_input", ): - print("Running transition step") - # raise NotImplementedError() + activity.heartbeat("Running transition step") # Get transition info transition_data = transition_info.model_dump(by_alias=False) @@ -150,16 +145,9 @@ async def transition_step( developer_id=context.developer_id, execution_id=context.execution.id, transition_id=uuid4(), - **transition_data, - ) - - update_execution_query( - developer_id=context.developer_id, + update_execution_status=True, task_id=context.task.id, - execution_id=context.execution.id, - data=UpdateExecutionRequest( - status=execution_status, - ), + **transition_data, ) # Raise if it's a waiting step diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 850aa700a..1bd4241b2 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -25,7 +25,7 @@ class ExecutionInput(BaseModel): developer_id: UUID execution: Execution - task: TaskSpec + task: TaskSpecDef agent: Agent tools: list[Tool] arguments: dict[str, Any] diff --git a/agents-api/agents_api/models/execution/create_execution_transition.py b/agents-api/agents_api/models/execution/create_execution_transition.py index 06874016d..df538be79 100644 --- a/agents-api/agents_api/models/execution/create_execution_transition.py +++ b/agents-api/agents_api/models/execution/create_execution_transition.py @@ -5,7 +5,11 @@ from pycozo.client import QueryException from pydantic import ValidationError -from ...autogen.openapi_model import CreateTransitionRequest, Transition +from ...autogen.openapi_model import ( + CreateTransitionRequest, + Transition, + UpdateExecutionRequest, +) from ...common.utils.cozo import cozo_process_mutate_data from ..utils import ( cozo_query, @@ -15,6 +19,7 @@ verify_developer_owns_resource_query, wrap_in_class, ) +from .update_execution import update_execution valid_transitions = { # Start state @@ -29,6 +34,16 @@ "step": ["wait", "error", "step", "finish", "cancelled"], } +transition_to_execution_status = { + "init": "queued", + "wait": "awaiting_input", + "resume": "running", + "step": "running", + "finish": "succeeded", + "error": "failed", + "cancelled": "cancelled", +} + @rewrap_exceptions( { @@ -46,17 +61,22 @@ def create_execution_transition( *, developer_id: UUID, execution_id: UUID, - transition_id: UUID | None = None, data: CreateTransitionRequest, + # Only one of these needed + transition_id: UUID | None = None, task_token: str | None = None, + # Only required for updating the execution status as well + update_execution_status: bool = False, + task_id: UUID | None = None, ) -> tuple[list[str], dict]: transition_id = transition_id or uuid4() data.metadata = data.metadata or {} data.execution_id = execution_id + # Prepare the transition data transition_data = data.model_dump(exclude_unset=True) - columns, values = cozo_process_mutate_data( + columns, transition_values = cozo_process_mutate_data( { **transition_data, "task_token": task_token, @@ -87,8 +107,9 @@ def create_execution_transition( :assert some """ + # Prepare the insert query insert_query = f""" - ?[{columns}] <- $values + ?[{columns}] <- $transition_values :insert transitions {{ {columns} @@ -97,6 +118,29 @@ def create_execution_transition( :returning """ + validate_status_query, update_execution_query, update_execution_params = ( + "", + "", + {}, + ) + + if update_execution_status: + assert ( + task_id is not None + ), "task_id is required for updating the execution status" + + # Prepare the execution update query + [*_, validate_status_query, update_execution_query], update_execution_params = ( + update_execution.__wrapped__( + developer_id=developer_id, + task_id=task_id, + execution_id=execution_id, + data=UpdateExecutionRequest( + status=transition_to_execution_status[data.type] + ), + ) + ) + queries = [ verify_developer_id_query(developer_id), verify_developer_owns_resource_query( @@ -105,6 +149,8 @@ def create_execution_transition( execution_id=execution_id, parents=[("agents", "agent_id"), ("tasks", "task_id")], ), + validate_status_query, + update_execution_query, check_last_transition_query, insert_query, ] @@ -112,8 +158,9 @@ def create_execution_transition( return ( queries, { - "values": values, + "transition_values": transition_values, "next_type": data.type, "valid_transitions": valid_transitions, + **update_execution_params, }, ) diff --git a/agents-api/tests/test_execution_queries.py b/agents-api/tests/test_execution_queries.py index c4dec7c5f..af81f75ff 100644 --- a/agents-api/tests/test_execution_queries.py +++ b/agents-api/tests/test_execution_queries.py @@ -83,3 +83,31 @@ def _(client=cozo_client, developer_id=test_developer_id, execution=test_executi assert result is not None assert result.type == "step" assert result.output == {"result": "test"} + + +@test("model: create execution transition with execution update") +def _( + client=cozo_client, + developer_id=test_developer_id, + task=test_task, + execution=test_execution, +): + result = create_execution_transition( + developer_id=developer_id, + execution_id=execution.id, + data=CreateTransitionRequest( + **{ + "type": "step", + "output": {"result": "test"}, + "current": ["main", 0], + "next": None, + } + ), + task_id=task.id, + update_execution_status=True, + client=client, + ) + + assert result is not None + assert result.type == "step" + assert result.output == {"result": "test"} diff --git a/agents-api/tests/test_task_routes.py b/agents-api/tests/test_task_routes.py index fcae8e979..8f2eef7c6 100644 --- a/agents-api/tests/test_task_routes.py +++ b/agents-api/tests/test_task_routes.py @@ -3,24 +3,24 @@ from ward import test -from tests.fixtures import client, make_request, test_execution, test_task +from tests.fixtures import client, make_request, test_agent @test("route: unauthorized should fail") -def _(client=client): +def _(client=client, agent=test_agent): data = dict( name="test user", main={ "kind_": "evaluate", "evaluate": { "additionalProp1": "value1", - } + }, }, ) response = client.request( method="POST", - url="/tasks", + url=f"/agents/{str(agent.id)}/tasks", data=data, ) @@ -28,20 +28,20 @@ def _(client=client): @test("route: create task") -def _(make_request=make_request): +def _(make_request=make_request, agent=test_agent): data = dict( name="test user", main={ "kind_": "evaluate", "evaluate": { "additionalProp1": "value1", - } + }, }, ) response = make_request( method="POST", - url="/tasks", + url=f"/agents/{str(agent.id)}/tasks", json=data, )