Skip to content

Commit

Permalink
chore: misc refactor + fixed list file route test
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Dec 25, 2024
1 parent 3caf682 commit eaada88
Show file tree
Hide file tree
Showing 22 changed files with 409 additions and 185 deletions.
23 changes: 23 additions & 0 deletions agents-api/agents_api/queries/executions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# ruff: noqa: F401, F403, F405

"""
The `execution` module provides SQL query functions for managing executions
in the TimescaleDB database. This includes operations for:
- Creating new executions
- Deleting executions
- Retrieving execution history
- Listing executions with filtering and pagination
"""

from .count_executions import count_executions
from .create_execution import create_execution
from .create_execution_transition import create_execution_transition
Expand All @@ -9,3 +19,16 @@
from .list_executions import list_executions
from .lookup_temporal_data import lookup_temporal_data
from .prepare_execution_input import prepare_execution_input


__all__ = [
"count_executions",
"create_execution",
"create_execution_transition",
"get_execution",
"get_execution_transition",
"list_execution_transitions",
"list_executions",
"lookup_temporal_data",
"prepare_execution_input",
]
54 changes: 39 additions & 15 deletions agents-api/agents_api/queries/executions/count_executions.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
from typing import Any, Literal, TypeVar
from typing import Any, Literal
from uuid import UUID

from beartype import beartype

from fastapi import HTTPException
from sqlglot import parse_one
import asyncpg
from ..utils import (
pg_query,
wrap_in_class,
rewrap_exceptions,
partialclass,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

sql_query = """SELECT COUNT(*) FROM latest_executions
# Query to count executions for a given task
execution_count_query = parse_one("""
SELECT COUNT(*) FROM latest_executions
WHERE
developer_id = $1
AND task_id = $2;
"""
""").sql(pretty=True)


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=400),
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@rewrap_exceptions(
{
asyncpg.NoDataFoundError: partialclass(
HTTPException,
status_code=404,
detail="No executions found for the specified task"
),
asyncpg.ForeignKeyViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer or task does not exist"
),
}
)
@wrap_in_class(dict, one=True)
@pg_query
@beartype
Expand All @@ -33,4 +43,18 @@ async def count_executions(
developer_id: UUID,
task_id: UUID,
) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]:
return (sql_query, [developer_id, task_id], "fetchrow")
"""
Count the number of executions for a given task.
Parameters:
developer_id (UUID): The ID of the developer.
task_id (UUID): The ID of the task.
Returns:
tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for counting executions.
"""
return (
execution_count_query,
[developer_id, task_id],
"fetchrow",
)
46 changes: 34 additions & 12 deletions agents-api/agents_api/queries/executions/create_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@
from ...common.utils.datetime import utcnow
from ...common.utils.types import dict_like
from ...metrics.counters import increase_counter
from sqlglot import parse_one
import asyncpg
from fastapi import HTTPException
from ..utils import (
pg_query,
wrap_in_class,
rewrap_exceptions,
partialclass,
)
from .constants import OUTPUT_UNNEST_KEY

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

sql_query = """
create_execution_query = parse_one("""
INSERT INTO executions
(
developer_id,
Expand All @@ -37,16 +40,23 @@
1
)
RETURNING *;
"""
""").sql(pretty=True)


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=400),
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@rewrap_exceptions(
{
asyncpg.NoDataFoundError: partialclass(
HTTPException,
status_code=404,
detail="No executions found for the specified task"
),
asyncpg.ForeignKeyViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer or task does not exist"
),
}
)
@wrap_in_class(
Execution,
one=True,
Expand All @@ -67,6 +77,18 @@ async def create_execution(
execution_id: UUID | None = None,
data: Annotated[CreateExecutionRequest | dict, dict_like(CreateExecutionRequest)],
) -> tuple[str, list]:
"""
Create a new execution.
Parameters:
developer_id (UUID): The ID of the developer.
task_id (UUID): The ID of the task.
execution_id (UUID | None): The ID of the execution.
data (CreateExecutionRequest | dict): The data for the execution.
Returns:
tuple[str, list]: SQL query and parameters for creating the execution.
"""
execution_id = execution_id or uuid7()

developer_id = str(developer_id)
Expand All @@ -86,7 +108,7 @@ async def create_execution(
execution_data["output"] = {OUTPUT_UNNEST_KEY: execution_data["output"]}

return (
sql_query,
create_execution_query,
[
developer_id,
task_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,20 @@
CreateTransitionRequest,
Transition,
)
import asyncpg
from fastapi import HTTPException
from sqlglot import parse_one
from ...common.utils.datetime import utcnow
from ...metrics.counters import increase_counter
from ..utils import (
pg_query,
wrap_in_class,
rewrap_exceptions,
partialclass,
)

sql_query = """
# Query to create a transition
create_execution_transition_query = parse_one("""
INSERT INTO transitions
(
execution_id,
Expand Down Expand Up @@ -43,7 +49,7 @@
$10
)
RETURNING *;
"""
""").sql(pretty=True)


def validate_transition_targets(data: CreateTransitionRequest) -> None:
Expand Down Expand Up @@ -80,13 +86,20 @@ def validate_transition_targets(data: CreateTransitionRequest) -> None:
raise ValueError(f"Invalid transition type: {data.type}")


# rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=400),
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@rewrap_exceptions(
{
asyncpg.NoDataFoundError: partialclass(
HTTPException,
status_code=404,
detail="No executions found for the specified task"
),
asyncpg.ForeignKeyViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer or task does not exist"
),
}
)
@wrap_in_class(
Transition,
transform=lambda d: {
Expand All @@ -111,6 +124,19 @@ async def create_execution_transition(
transition_id: UUID | None = None,
task_token: str | None = None,
) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]:
"""
Create a new execution transition.
Parameters:
developer_id (UUID): The ID of the developer.
execution_id (UUID): The ID of the execution.
data (CreateTransitionRequest): The data for the transition.
transition_id (UUID | None): The ID of the transition.
task_token (str | None): The task token.
Returns:
tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for creating the transition.
"""
transition_id = transition_id or uuid7()
data.metadata = data.metadata or {}
data.execution_id = execution_id
Expand Down Expand Up @@ -140,7 +166,7 @@ async def create_execution_transition(
)

return (
sql_query,
create_execution_transition_query,
[
execution_id,
transition_id,
Expand Down
50 changes: 35 additions & 15 deletions agents-api/agents_api/queries/executions/create_temporal_lookup.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from typing import TypeVar
from uuid import UUID

from beartype import beartype
from temporalio.client import WorkflowHandle
from sqlglot import parse_one
import asyncpg
from fastapi import HTTPException
from uuid import UUID

from ...metrics.counters import increase_counter
from ..utils import (
pg_query,
rewrap_exceptions,
partialclass,
)

T = TypeVar("T")

sql_query = """
# Query to create a temporal lookup
create_temporal_lookup_query = parse_one("""
INSERT INTO temporal_executions_lookup
(
execution_id,
Expand All @@ -29,17 +32,23 @@
$5
)
RETURNING *;
"""
""").sql(pretty=True)


# @rewrap_exceptions(
# {
# AssertionError: partialclass(HTTPException, status_code=404),
# QueryException: partialclass(HTTPException, status_code=400),
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@rewrap_exceptions(
{
asyncpg.NoDataFoundError: partialclass(
HTTPException,
status_code=404,
detail="No executions found for the specified task"
),
asyncpg.ForeignKeyViolationError: partialclass(
HTTPException,
status_code=404,
detail="The specified developer or task does not exist"
),
}
)
@pg_query
@increase_counter("create_temporal_lookup")
@beartype
Expand All @@ -49,11 +58,22 @@ async def create_temporal_lookup(
execution_id: UUID,
workflow_handle: WorkflowHandle,
) -> tuple[str, list]:
"""
Create a temporal lookup for a given execution.
Parameters:
developer_id (UUID): The ID of the developer.
execution_id (UUID): The ID of the execution.
workflow_handle (WorkflowHandle): The workflow handle.
Returns:
tuple[str, list]: SQL query and parameters for creating the temporal lookup.
"""
developer_id = str(developer_id)
execution_id = str(execution_id)

return (
sql_query,
create_temporal_lookup_query,
[
execution_id,
workflow_handle.id,
Expand Down
Loading

0 comments on commit eaada88

Please sign in to comment.