From adb2d9f4e7d3a768be546d16f22bf003f85ad11a Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 15 Aug 2024 15:54:28 +0300 Subject: [PATCH] fix: Fix tasks routes and tests --- .../execution/prepare_execution_input.py | 2 +- .../routers/tasks/create_task_execution.py | 17 +-- .../routers/tasks/get_execution_details.py | 11 +- .../routers/tasks/get_task_details.py | 41 ++++--- .../tasks/list_execution_transitions.py | 16 ++- .../routers/tasks/list_task_executions.py | 17 ++- .../agents_api/routers/tasks/list_tasks.py | 6 +- .../routers/tasks/patch_execution.py | 25 ++--- agents-api/tests/fixtures.py | 32 ++++++ agents-api/tests/test_task_routes.py | 103 ++++++++---------- 10 files changed, 140 insertions(+), 130 deletions(-) diff --git a/agents-api/agents_api/models/execution/prepare_execution_input.py b/agents-api/agents_api/models/execution/prepare_execution_input.py index c1a767740..2a07d4d3f 100644 --- a/agents-api/agents_api/models/execution/prepare_execution_input.py +++ b/agents-api/agents_api/models/execution/prepare_execution_input.py @@ -65,7 +65,7 @@ def prepare_execution_input( ) # Remove the outer curly braces - task_query = task_query.strip()[1:-1] + task_query = task_query[-1].strip() task_fields = ( "id", diff --git a/agents-api/agents_api/routers/tasks/create_task_execution.py b/agents-api/agents_api/routers/tasks/create_task_execution.py index 9f78de8ca..6f2c0bc3a 100644 --- a/agents-api/agents_api/routers/tasks/create_task_execution.py +++ b/agents-api/agents_api/routers/tasks/create_task_execution.py @@ -42,24 +42,17 @@ async def create_task_execution( x_developer_id: Annotated[UUID4, Depends(get_developer_id)], ) -> ResourceCreatedResponse: try: - task = [ - row.to_dict() - for _, row in get_task_query( - task_id=task_id, developer_id=x_developer_id - ).iterrows() - ][0] + task = get_task_query( + task_id=task_id, developer_id=x_developer_id + ) + task_data = task.model_dump() - validate(data.input, task["input_schema"]) + validate(data.input, task_data["input_schema"]) except ValidationError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request arguments schema", ) - except (IndexError, KeyError): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Task not found", - ) except QueryException as e: if e.code == "transact::assertion_failure": raise HTTPException( diff --git a/agents-api/agents_api/routers/tasks/get_execution_details.py b/agents-api/agents_api/routers/tasks/get_execution_details.py index 6a9a01caa..e6f87b8af 100644 --- a/agents-api/agents_api/routers/tasks/get_execution_details.py +++ b/agents-api/agents_api/routers/tasks/get_execution_details.py @@ -14,13 +14,8 @@ @router.get("/executions/{execution_id}", tags=["executions"]) async def get_execution_details(execution_id: UUID4) -> Execution: try: - res = [ - row.to_dict() - for _, row in get_execution_query(execution_id=execution_id).iterrows() - ][0] - return Execution(**res) - except (IndexError, KeyError): + return get_execution_query(execution_id=execution_id) + except AssertionError: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Execution not found", + status_code=status.HTTP_404_NOT_FOUND, detail="Execution not found" ) diff --git a/agents-api/agents_api/routers/tasks/get_task_details.py b/agents-api/agents_api/routers/tasks/get_task_details.py index dbf8bf7da..0890d5aeb 100644 --- a/agents-api/agents_api/routers/tasks/get_task_details.py +++ b/agents-api/agents_api/routers/tasks/get_task_details.py @@ -13,35 +13,32 @@ from .router import router -@router.get("/agents/{agent_id}/tasks/{task_id}", tags=["tasks"]) +@router.get("/tasks/{task_id}", tags=["tasks"]) async def get_task_details( task_id: UUID4, - agent_id: UUID4, x_developer_id: Annotated[UUID4, Depends(get_developer_id)], ) -> Task: + not_found = HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Task not found" + ) + try: - resp = [ - row.to_dict() - for _, row in get_task_query( - agent_id=agent_id, task_id=task_id, developer_id=x_developer_id - ).iterrows() - ][0] - - for workflow in resp["workflows"]: - if workflow["name"] == "main": - resp["main"] = workflow["steps"] - break - - return Task(**resp) - except (IndexError, KeyError): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Task not found", + task = get_task_query( + developer_id=x_developer_id, task_id=task_id ) + task_data = task.model_dump() + except AssertionError: + raise not_found except QueryException as e: if e.code == "transact::assertion_failure": - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Task not found" - ) + raise not_found raise + + for workflow in task_data.get("workflows", []): + if workflow["name"] == "main": + task_data["main"] = workflow.get("steps", []) + break + + return Task(**task_data) + diff --git a/agents-api/agents_api/routers/tasks/list_execution_transitions.py b/agents-api/agents_api/routers/tasks/list_execution_transitions.py index 36f61601e..98b7f1cd7 100644 --- a/agents-api/agents_api/routers/tasks/list_execution_transitions.py +++ b/agents-api/agents_api/routers/tasks/list_execution_transitions.py @@ -1,3 +1,5 @@ +from typing import Literal + from pydantic import UUID4 from agents_api.autogen.openapi_model import ( @@ -16,13 +18,17 @@ async def list_execution_transitions( execution_id: UUID4, limit: int = 100, offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Transition]: - res = list_execution_transitions_query( - execution_id=execution_id, limit=limit, offset=offset - ) - return ListResponse[Transition]( - items=[Transition(**row.to_dict()) for _, row in res.iterrows()] + transitions = list_execution_transitions_query( + execution_id=execution_id, + limit=limit, + offset=offset, + sort_by=sort_by, + direction=direction, ) + return ListResponse[Transition](items=transitions) # @router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"]) diff --git a/agents-api/agents_api/routers/tasks/list_task_executions.py b/agents-api/agents_api/routers/tasks/list_task_executions.py index 718d13a0d..24a541c7b 100644 --- a/agents-api/agents_api/routers/tasks/list_task_executions.py +++ b/agents-api/agents_api/routers/tasks/list_task_executions.py @@ -1,4 +1,4 @@ -from typing import Annotated +from typing import Annotated, Literal from fastapi import Depends from pydantic import UUID4 @@ -21,10 +21,15 @@ async def list_task_executions( x_developer_id: Annotated[UUID4, Depends(get_developer_id)], limit: int = 100, offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", ) -> ListResponse[Execution]: - res = list_task_executions_query( - task_id=task_id, developer_id=x_developer_id, limit=limit, offse=offset - ) - return ListResponse[Execution]( - items=[Execution(**row.to_dict()) for _, row in res.iterrows()] + executions = list_task_executions_query( + task_id=task_id, + developer_id=x_developer_id, + limit=limit, + offse=offset, + sort_by=sort_by, + direction=direction, ) + return ListResponse[Execution](items=executions) diff --git a/agents-api/agents_api/routers/tasks/list_tasks.py b/agents-api/agents_api/routers/tasks/list_tasks.py index 771b9f38c..5066fcc96 100644 --- a/agents-api/agents_api/routers/tasks/list_tasks.py +++ b/agents-api/agents_api/routers/tasks/list_tasks.py @@ -32,10 +32,10 @@ async def list_tasks( ) tasks = [] - for _, row in query_results.iterrows(): - row_dict = row.to_dict() + for row in query_results: + row_dict = row.model_dump() - for workflow in row_dict["workflows"]: + for workflow in row_dict.get("workflows", []): if workflow["name"] == "main": row_dict["main"] = workflow["steps"] break diff --git a/agents-api/agents_api/routers/tasks/patch_execution.py b/agents-api/agents_api/routers/tasks/patch_execution.py index 74ae3fc72..56c38fc38 100644 --- a/agents-api/agents_api/routers/tasks/patch_execution.py +++ b/agents-api/agents_api/routers/tasks/patch_execution.py @@ -5,6 +5,7 @@ from agents_api.autogen.openapi_model import ( Execution, + ResourceUpdatedResponse, UpdateExecutionRequest, ) from agents_api.dependencies.developer_id import get_developer_id @@ -22,20 +23,10 @@ async def patch_execution( task_id: UUID4, execution_id: UUID4, data: UpdateExecutionRequest, -) -> Execution: - try: - res = [ - row.to_dict() - for _, row in update_execution_query( - developer_id=x_developer_id, - task_id=task_id, - execution_id=execution_id, - data=data, - ).iterrows() - ][0] - return Execution(**res) - except (IndexError, KeyError): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Execution not found", - ) +) -> ResourceUpdatedResponse: + return update_execution_query( + developer_id=x_developer_id, + task_id=task_id, + execution_id=execution_id, + data=data, + ) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 437c95083..e8dfff790 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -16,6 +16,7 @@ CreateSessionRequest, CreateTaskRequest, CreateToolRequest, + CreateTransitionRequest, CreateUserRequest, ) from agents_api.env import api_key, api_key_header_name @@ -25,6 +26,9 @@ from agents_api.models.docs.create_doc import create_doc from agents_api.models.docs.delete_doc import delete_doc from agents_api.models.execution.create_execution import create_execution +from agents_api.models.execution.create_execution_transition import ( + create_execution_transition, +) from agents_api.models.session.create_session import create_session from agents_api.models.session.delete_session import delete_session from agents_api.models.task.create_task import create_task @@ -308,6 +312,34 @@ def test_execution( ) +@fixture(scope="global") +def test_transition( + client=cozo_client, + developer_id=test_developer_id, + execution=test_execution, +): + transition = create_execution_transition( + developer_id=developer_id, + execution_id=execution.id, + data=CreateTransitionRequest( + type="step", + output={}, + current=[], + next=[], + ), + client=client, + ) + + yield transition + + client.run( + f""" + ?[transition_id] <- ["{str(transition.id)}"] + :delete transitions {{ transition_id }} + """ + ) + + @fixture(scope="global") def test_tool( client=cozo_client, diff --git a/agents-api/tests/test_task_routes.py b/agents-api/tests/test_task_routes.py index 80aae9e7f..ce75348bb 100644 --- a/agents-api/tests/test_task_routes.py +++ b/agents-api/tests/test_task_routes.py @@ -4,7 +4,14 @@ from ward import test -from tests.fixtures import client, make_request, test_agent +from tests.fixtures import ( + client, + make_request, + test_agent, + test_execution, + test_task, + test_transition, +) @test("route: unauthorized should fail") @@ -54,8 +61,7 @@ def _(make_request=make_request, agent=test_agent): @test("route: create task execution") -def _(make_request=make_request): - task_id = str(uuid4()) +def _(make_request=make_request, task=test_task): data = dict( input={}, metadata={}, @@ -63,7 +69,7 @@ def _(make_request=make_request): response = make_request( method="POST", - url=f"/tasks/{task_id}/executions", + url=f"/tasks/{str(task.id)}/executions", json=data, ) @@ -83,12 +89,10 @@ def _(make_request=make_request): @test("route: get execution exists") -def _(make_request=make_request): - execution_id = str(uuid4()) - +def _(make_request=make_request, execution=test_execution): response = make_request( method="GET", - url=f"/executions/{execution_id}", + url=f"/executions/{str(execution.id)}", ) assert response.status_code == 200 @@ -96,106 +100,93 @@ def _(make_request=make_request): @test("route: get task not exists") def _(make_request=make_request): - data = dict( - name="test user", - main="test user about", - ) + task_id = str(uuid4()) response = make_request( - method="POST", - url="/tasks", - json=data, + method="GET", + url=f"/tasks/{task_id}", ) - assert response.status_code == 201 + assert response.status_code == 404 @test("route: get task exists") -def _(make_request=make_request): - data = dict( - name="test user", - main="test user about", - ) - +def _(make_request=make_request, task=test_task): response = make_request( - method="POST", - url="/tasks", - json=data, + method="GET", + url=f"/tasks/{str(task.id)}", ) - assert response.status_code == 201 + assert response.status_code == 200 @test("model: list execution transitions") -def _(make_request=make_request): +def _(make_request=make_request, transition=test_transition): response = make_request( method="GET", - url="/users", + url=f"/executions/{str(transition.execution_id)}/transitions", ) assert response.status_code == 200 response = response.json() - users = response["items"] + transitions = response["items"] - assert isinstance(users, list) - assert len(users) > 0 + assert isinstance(transitions, list) + assert len(transitions) > 0 @test("model: list task executions") -def _(make_request=make_request): +def _(make_request=make_request, execution=test_execution): response = make_request( method="GET", - url="/users", + url=f"/tasks/{str(execution.task_id)}/executions", ) assert response.status_code == 200 response = response.json() - users = response["items"] + executions = response["items"] - assert isinstance(users, list) - assert len(users) > 0 + assert isinstance(executions, list) + assert len(executions) > 0 @test("model: list tasks") -def _(make_request=make_request): +def _(make_request=make_request, agent=test_agent): response = make_request( method="GET", - url="/users", + url=f"/agents/{str(agent.id)}/tasks", ) assert response.status_code == 200 response = response.json() - users = response["items"] + tasks = response["items"] - assert isinstance(users, list) - assert len(users) > 0 + assert isinstance(tasks, list) + assert len(tasks) > 0 @test("model: patch execution") -def _(make_request=make_request): +def _(make_request=make_request, execution=test_execution): + data = dict( + status="running", + ) + response = make_request( - method="GET", - url="/users", + method="PATCH", + url=f"/tasks/{str(execution.task_id)}/executions/{str(execution.id)}", + json=data, ) assert response.status_code == 200 - response = response.json() - users = response["items"] - - assert isinstance(users, list) - assert len(users) > 0 + execution_id = response.json()["id"] -@test("model: update execution") -def _(make_request=make_request): response = make_request( method="GET", - url="/users", + url=f"/executions/{execution_id}", ) assert response.status_code == 200 - response = response.json() - users = response["items"] + execution = response.json() - assert isinstance(users, list) - assert len(users) > 0 + assert execution["status"] == "running"