Skip to content

Commit

Permalink
chore: misc task route fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Dec 28, 2024
1 parent f2b3039 commit 2fb9970
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST
LIMIT $2 OFFSET $3;
"""
# Query to get a single transition
get_execution_transition_query = """
SELECT * FROM transitions
WHERE
execution_id = $1
AND transition_id = $2;
"""


def _transform(d):
Expand Down Expand Up @@ -53,11 +60,12 @@ def _transform(d):
Transition,
transform=_transform,
)
@pg_query
@pg_query(debug=True)
@beartype
async def list_execution_transitions(
*,
execution_id: UUID,
transition_id: UUID | None = None,
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at"] = "created_at",
Expand All @@ -76,6 +84,14 @@ async def list_execution_transitions(
Returns:
tuple[str, list]: SQL query and parameters for listing execution transitions.
"""
if transition_id is not None:
return (
get_execution_transition_query,
[
str(execution_id),
str(transition_id),
],
)
return (
list_execution_transitions_query,
[
Expand Down
3 changes: 1 addition & 2 deletions agents-api/agents_api/routers/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .list_execution_transitions import list_execution_transitions
from .list_task_executions import list_task_executions
from .list_tasks import list_tasks

# from .patch_execution import patch_execution
from .router import router
from .stream_transitions_events import stream_transitions_events
from .update_execution import update_execution
9 changes: 0 additions & 9 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,6 @@ async def create_task_execution(
detail="Invalid request arguments schema",
)

# except QueryException as e:
# if e.code == "transact::assertion_failure":
# raise HTTPException(
# status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
# )

# raise

# get developer data
developer: Developer = await get_developer(developer_id=x_developer_id)

Expand All @@ -159,7 +151,6 @@ async def create_task_execution(

background_tasks.add_task(
create_temporal_lookup,
#
execution_id=execution.id,
workflow_handle=handle,
)
Expand Down
39 changes: 20 additions & 19 deletions agents-api/agents_api/routers/tasks/list_execution_transitions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Literal
from uuid import UUID

from fastapi import HTTPException, status

from ...autogen.openapi_model import (
ListResponse,
Transition,
Expand Down Expand Up @@ -30,22 +32,21 @@ async def list_execution_transitions(
return ListResponse[Transition](items=transitions)


# TODO: Do we need this?
# @router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"])
# async def get_execution_transition(
# execution_id: UUID,
# transition_id: UUID,
# ) -> Transition:
# try:
# res = [
# row.to_dict()
# for _, row in get_execution_transition_query(
# execution_id, transition_id
# ).iterrows()
# ][0]
# return Transition(**res)
# except (IndexError, KeyError):
# raise HTTPException(
# status_code=status.HTTP_404_NOT_FOUND,
# detail="Transition not found",
# )
@router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"])
async def get_execution_transition(
execution_id: UUID,
transition_id: UUID,
) -> Transition:
try:
transitions = await list_execution_transitions_query(
execution_id=execution_id,
transition_id=transition_id,
)
if not transitions:
raise IndexError
return transitions[0]
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Transition not found",
)
2 changes: 0 additions & 2 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,13 @@ async def test_execution_started(
# Start the execution
await create_execution_transition(
developer_id=developer_id,
# task_id=task.id,
execution_id=execution.id,
data=CreateTransitionRequest(
type="init",
output={},
current={"workflow": "main", "step": 0},
next={"workflow": "main", "step": 0},
),
# update_execution_status=True,
connection_pool=pool,
)
yield execution
Expand Down
100 changes: 75 additions & 25 deletions agents-api/tests/test_task_routes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
# Tests for task routes

from agents_api.autogen.openapi_model import (
Transition,
)
from agents_api.queries.executions.create_execution_transition import (
create_execution_transition,
)
from uuid_extensions import uuid7
from ward import test
from ward import skip, test

from .fixtures import (
CreateTransitionRequest,
client,
create_db_pool,
make_request,
pg_dsn,
test_agent,
test_developer_id,
test_execution,
test_execution_started,
test_task,
test_transition,
)
from .utils import patch_testing_temporal

Expand Down Expand Up @@ -121,8 +131,8 @@ def _(make_request=make_request, task=test_task):
assert response.status_code == 200


@test("route: list execution transitions")
def _(make_request=make_request, execution=test_execution, transition=test_transition):
@test("route: list all execution transition")
async def _(make_request=make_request, execution=test_execution_started):
response = make_request(
method="GET",
url=f"/executions/{execution.id!s}/transitions",
Expand All @@ -136,6 +146,46 @@ def _(make_request=make_request, execution=test_execution, transition=test_trans
assert len(transitions) > 0


@test("route: list a single execution transition")
async def _(
dsn=pg_dsn,
make_request=make_request,
execution=test_execution_started,
developer_id=test_developer_id,
):
pool = await create_db_pool(dsn=dsn)

# Create a transition
transition = await create_execution_transition(
developer_id=developer_id,
execution_id=execution.id,
data=CreateTransitionRequest(
type="step",
output={},
current={"workflow": "main", "step": 0},
next={"workflow": "wf1", "step": 1},
),
connection_pool=pool,
)

response = make_request(
method="GET",
url=f"/executions/{execution.id!s}/transitions/{transition.id!s}",
)

assert response.status_code == 200
response = response.json()

assert isinstance(transition, Transition)
assert str(transition.id) == response["id"]
assert transition.type == response["type"]
assert transition.output == response["output"]
assert transition.current.workflow == response["current"]["workflow"]
assert transition.current.step == response["current"]["step"]
assert transition.next.workflow == response["next"]["workflow"]
assert transition.next.step == response["next"]["step"]


@test("route: list task executions")
def _(make_request=make_request, execution=test_execution):
response = make_request(
Expand Down Expand Up @@ -191,10 +241,8 @@ def _(make_request=make_request, agent=test_agent):
assert len(tasks) > 0


# FIXME: This test is failing


@test("route: patch execution")
@skip("Temporal connextion issue")
@test("route: update execution")
async def _(make_request=make_request, task=test_task):
data = {
"input": {},
Expand All @@ -210,26 +258,28 @@ async def _(make_request=make_request, task=test_task):

execution = response.json()

data = {
"status": "running",
}
data = {
"status": "running",
}

response = make_request(
method="PATCH",
url=f"/tasks/{task.id!s}/executions/{execution['id']!s}",
json=data,
)
execution_id = execution["id"]

assert response.status_code == 200
response = make_request(
method="PUT",
url=f"/executions/{execution_id}",
json=data,
)

execution_id = response.json()["id"]
assert response.status_code == 200

response = make_request(
method="GET",
url=f"/executions/{execution_id}",
)
execution_id = response.json()["id"]

assert response.status_code == 200
execution = response.json()
response = make_request(
method="GET",
url=f"/executions/{execution_id}",
)

assert response.status_code == 200
execution = response.json()

assert execution["status"] == "running"
assert execution["status"] == "running"

0 comments on commit 2fb9970

Please sign in to comment.