Skip to content

Commit

Permalink
fix: Fix tasks routes and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Aug 15, 2024
1 parent 9dca03e commit adb2d9f
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def prepare_execution_input(
)

# Remove the outer curly braces
task_query = task_query.strip()[1:-1]
task_query = task_query[-1].strip()

task_fields = (
"id",
Expand Down
17 changes: 5 additions & 12 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,17 @@ async def create_task_execution(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
try:
task = [
row.to_dict()
for _, row in get_task_query(
task_id=task_id, developer_id=x_developer_id
).iterrows()
][0]
task = get_task_query(
task_id=task_id, developer_id=x_developer_id
)
task_data = task.model_dump()

validate(data.input, task["input_schema"])
validate(data.input, task_data["input_schema"])
except ValidationError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid request arguments schema",
)
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Task not found",
)
except QueryException as e:
if e.code == "transact::assertion_failure":
raise HTTPException(
Expand Down
11 changes: 3 additions & 8 deletions agents-api/agents_api/routers/tasks/get_execution_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@
@router.get("/executions/{execution_id}", tags=["executions"])
async def get_execution_details(execution_id: UUID4) -> Execution:
try:
res = [
row.to_dict()
for _, row in get_execution_query(execution_id=execution_id).iterrows()
][0]
return Execution(**res)
except (IndexError, KeyError):
return get_execution_query(execution_id=execution_id)
except AssertionError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Execution not found",
status_code=status.HTTP_404_NOT_FOUND, detail="Execution not found"
)
41 changes: 19 additions & 22 deletions agents-api/agents_api/routers/tasks/get_task_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,32 @@
from .router import router


@router.get("/agents/{agent_id}/tasks/{task_id}", tags=["tasks"])
@router.get("/tasks/{task_id}", tags=["tasks"])
async def get_task_details(
task_id: UUID4,
agent_id: UUID4,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> Task:
not_found = HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
)

try:
resp = [
row.to_dict()
for _, row in get_task_query(
agent_id=agent_id, task_id=task_id, developer_id=x_developer_id
).iterrows()
][0]

for workflow in resp["workflows"]:
if workflow["name"] == "main":
resp["main"] = workflow["steps"]
break

return Task(**resp)
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Task not found",
task = get_task_query(
developer_id=x_developer_id, task_id=task_id
)
task_data = task.model_dump()
except AssertionError:
raise not_found
except QueryException as e:
if e.code == "transact::assertion_failure":
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
)
raise not_found

raise

for workflow in task_data.get("workflows", []):
if workflow["name"] == "main":
task_data["main"] = workflow.get("steps", [])
break

return Task(**task_data)

16 changes: 11 additions & 5 deletions agents-api/agents_api/routers/tasks/list_execution_transitions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

from pydantic import UUID4

from agents_api.autogen.openapi_model import (
Expand All @@ -16,13 +18,17 @@ async def list_execution_transitions(
execution_id: UUID4,
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
) -> ListResponse[Transition]:
res = list_execution_transitions_query(
execution_id=execution_id, limit=limit, offset=offset
)
return ListResponse[Transition](
items=[Transition(**row.to_dict()) for _, row in res.iterrows()]
transitions = list_execution_transitions_query(
execution_id=execution_id,
limit=limit,
offset=offset,
sort_by=sort_by,
direction=direction,
)
return ListResponse[Transition](items=transitions)


# @router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"])
Expand Down
17 changes: 11 additions & 6 deletions agents-api/agents_api/routers/tasks/list_task_executions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated
from typing import Annotated, Literal

from fastapi import Depends
from pydantic import UUID4
Expand All @@ -21,10 +21,15 @@ async def list_task_executions(
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
) -> ListResponse[Execution]:
res = list_task_executions_query(
task_id=task_id, developer_id=x_developer_id, limit=limit, offse=offset
)
return ListResponse[Execution](
items=[Execution(**row.to_dict()) for _, row in res.iterrows()]
executions = list_task_executions_query(
task_id=task_id,
developer_id=x_developer_id,
limit=limit,
offse=offset,
sort_by=sort_by,
direction=direction,
)
return ListResponse[Execution](items=executions)
6 changes: 3 additions & 3 deletions agents-api/agents_api/routers/tasks/list_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ async def list_tasks(
)

tasks = []
for _, row in query_results.iterrows():
row_dict = row.to_dict()
for row in query_results:
row_dict = row.model_dump()

for workflow in row_dict["workflows"]:
for workflow in row_dict.get("workflows", []):
if workflow["name"] == "main":
row_dict["main"] = workflow["steps"]
break
Expand Down
25 changes: 8 additions & 17 deletions agents-api/agents_api/routers/tasks/patch_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from agents_api.autogen.openapi_model import (
Execution,
ResourceUpdatedResponse,
UpdateExecutionRequest,
)
from agents_api.dependencies.developer_id import get_developer_id
Expand All @@ -22,20 +23,10 @@ async def patch_execution(
task_id: UUID4,
execution_id: UUID4,
data: UpdateExecutionRequest,
) -> Execution:
try:
res = [
row.to_dict()
for _, row in update_execution_query(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
data=data,
).iterrows()
][0]
return Execution(**res)
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Execution not found",
)
) -> ResourceUpdatedResponse:
return update_execution_query(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
data=data,
)
32 changes: 32 additions & 0 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CreateSessionRequest,
CreateTaskRequest,
CreateToolRequest,
CreateTransitionRequest,
CreateUserRequest,
)
from agents_api.env import api_key, api_key_header_name
Expand All @@ -25,6 +26,9 @@
from agents_api.models.docs.create_doc import create_doc
from agents_api.models.docs.delete_doc import delete_doc
from agents_api.models.execution.create_execution import create_execution
from agents_api.models.execution.create_execution_transition import (
create_execution_transition,
)
from agents_api.models.session.create_session import create_session
from agents_api.models.session.delete_session import delete_session
from agents_api.models.task.create_task import create_task
Expand Down Expand Up @@ -308,6 +312,34 @@ def test_execution(
)


@fixture(scope="global")
def test_transition(
client=cozo_client,
developer_id=test_developer_id,
execution=test_execution,
):
transition = create_execution_transition(
developer_id=developer_id,
execution_id=execution.id,
data=CreateTransitionRequest(
type="step",
output={},
current=[],
next=[],
),
client=client,
)

yield transition

client.run(
f"""
?[transition_id] <- ["{str(transition.id)}"]
:delete transitions {{ transition_id }}
"""
)


@fixture(scope="global")
def test_tool(
client=cozo_client,
Expand Down
Loading

0 comments on commit adb2d9f

Please sign in to comment.