Skip to content

Commit

Permalink
Merge pull request #993 from julep-ai/f/executions-queries
Browse files Browse the repository at this point in the history
Executions queries
  • Loading branch information
whiterabbit1983 authored Dec 26, 2024
2 parents f0deed9 + e962645 commit 4e1c975
Show file tree
Hide file tree
Showing 18 changed files with 64 additions and 84 deletions.
5 changes: 2 additions & 3 deletions agents-api/agents_api/queries/executions/count_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ..utils import (
partialclass,
Expand All @@ -14,12 +13,12 @@
)

# Query to count executions for a given task
execution_count_query = parse_one("""
execution_count_query = """
SELECT COUNT(*) FROM latest_executions
WHERE
developer_id = $1
AND task_id = $2;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/queries/executions/create_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...autogen.openapi_model import CreateExecutionRequest, Execution
Expand All @@ -19,7 +18,7 @@
)
from .constants import OUTPUT_UNNEST_KEY

create_execution_query = parse_one("""
create_execution_query = """
INSERT INTO executions
(
developer_id,
Expand All @@ -39,7 +38,7 @@
1
)
RETURNING *;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...autogen.openapi_model import (
Expand All @@ -21,7 +20,7 @@
)

# Query to create a transition
create_execution_transition_query = parse_one("""
create_execution_transition_query = """
INSERT INTO transitions
(
execution_id,
Expand Down Expand Up @@ -49,7 +48,7 @@
$10
)
RETURNING *;
""").sql(pretty=True)
"""


# FIXME: Remove this function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from temporalio.client import WorkflowHandle

from ...metrics.counters import increase_counter
Expand All @@ -14,7 +13,7 @@
)

# Query to create a temporal lookup
create_temporal_lookup_query = parse_one("""
create_temporal_lookup_query = """
INSERT INTO temporal_executions_lookup
(
execution_id,
Expand All @@ -32,7 +31,7 @@
$5
)
RETURNING *;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand All @@ -54,22 +53,19 @@
@beartype
async def create_temporal_lookup(
*,
developer_id: UUID, # FIXME: Remove this parameter
execution_id: UUID,
workflow_handle: WorkflowHandle,
) -> tuple[str, list]:
"""
Create a temporal lookup for a given execution.
Parameters:
developer_id (UUID): The ID of the developer.
execution_id (UUID): The ID of the execution.
workflow_handle (WorkflowHandle): The workflow handle.
Returns:
tuple[str, list]: SQL query and parameters for creating the temporal lookup.
"""
developer_id = str(developer_id)
execution_id = str(execution_id)

return (
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/queries/executions/get_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import Execution
from ..utils import (
Expand All @@ -16,12 +15,12 @@
from .constants import OUTPUT_UNNEST_KEY

# Query to get an execution
get_execution_query = parse_one("""
get_execution_query = """
SELECT * FROM latest_executions
WHERE
execution_id = $1
LIMIT 1;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import Transition
from ..utils import (
Expand All @@ -16,13 +15,13 @@

# FIXME: Use latest_transitions instead of transitions
# Query to get an execution transition
get_execution_transition_query = parse_one("""
get_execution_transition_query = """
SELECT * FROM transitions
WHERE
transition_id = $1
OR task_token = $2
LIMIT 1;
""").sql(pretty=True)
"""


def _transform(d):
Expand Down Expand Up @@ -60,15 +59,13 @@ def _transform(d):
@beartype
async def get_execution_transition(
*,
developer_id: UUID, # FIXME: Remove this parameter
transition_id: UUID | None = None,
task_token: str | None = None,
) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]:
"""
Get an execution transition by its ID or task token.
Parameters:
developer_id (UUID): The ID of the developer.
transition_id (UUID | None): The ID of the transition.
task_token (str | None): The task token.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@
@beartype
async def get_paused_execution_token(
*,
developer_id: UUID,
execution_id: UUID,
) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]:
"""
Get a paused execution token for a given execution.
Parameters:
developer_id (UUID): The ID of the developer.
execution_id (UUID): The ID of the execution.
Returns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
)

# Query to get temporal workflow data
get_temporal_workflow_data_query = parse_one("""
get_temporal_workflow_data_query = """
SELECT id, run_id, result_run_id, first_execution_run_id FROM temporal_executions_lookup
WHERE
execution_id = $1
LIMIT 1;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import Transition
from ...common.utils.datetime import utcnow
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Query to list execution transitions
list_execution_transitions_query = parse_one("""
list_execution_transitions_query = """
SELECT * FROM transitions
WHERE
execution_id = $1
ORDER BY
CASE WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at END ASC NULLS LAST,
CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST,
CASE WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at END ASC NULLS LAST,
CASE WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at END DESC NULLS LAST
CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST
LIMIT $2 OFFSET $3;
""").sql(pretty=True)
"""


def _transform(d):
current_step = d.pop("current_step")
next_step = d.pop("next_step", None)

return {
"id": d["transition_id"],
"updated_at": utcnow(),
"current": {
"workflow": current_step[0],
"step": current_step[1],
Expand Down Expand Up @@ -60,7 +60,7 @@ async def list_execution_transitions(
execution_id: UUID,
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at", "updated_at"] = "created_at",
sort_by: Literal["created_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
) -> tuple[str, list]:
"""
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/queries/executions/list_executions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import Execution
from ..utils import (
Expand All @@ -16,7 +15,7 @@
from .constants import OUTPUT_UNNEST_KEY

# Query to list executions
list_executions_query = parse_one("""
list_executions_query = """
SELECT * FROM latest_executions
WHERE
developer_id = $1 AND
Expand All @@ -27,7 +26,7 @@
CASE WHEN $3 = 'updated_at' AND $4 = 'asc' THEN updated_at END ASC NULLS LAST,
CASE WHEN $3 = 'updated_at' AND $4 = 'desc' THEN updated_at END DESC NULLS LAST
LIMIT $5 OFFSET $6;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,24 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# FIXME: Check if this query is correct


# Query to lookup temporal data
lookup_temporal_data_query = parse_one("""
lookup_temporal_data_query = """
SELECT t.*
FROM
temporal_executions_lookup t,
executions e
WHERE
t.execution_id = e.execution_id
AND t.developer_id = e.developer_id
AND e.execution_id = $1
AND e.developer_id = $2
LIMIT 1;
""").sql(pretty=True)
"""


@rewrap_exceptions(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from uuid import UUID

from beartype import beartype
from sqlglot import parse_one

from ...common.protocol.tasks import ExecutionInput
from ..utils import (
Expand All @@ -10,7 +9,7 @@
)

# Query to prepare execution input
prepare_execution_input_query = parse_one("""
prepare_execution_input_query = """
SELECT * FROM
(
SELECT to_jsonb(a) AS agent FROM (
Expand Down Expand Up @@ -41,7 +40,7 @@
LIMIT 1
) t
) AS task;
""").sql(pretty=True)
"""
# (
# SELECT to_jsonb(e) AS execution FROM (
# SELECT * FROM latest_executions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ async def create_task_execution(
background_tasks.add_task(
create_temporal_lookup,
#
developer_id=x_developer_id,
execution_id=execution.id,
workflow_handle=handle,
)
Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/routers/tasks/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ async def update_execution(
raise HTTPException(status_code=500, detail="Failed to stop execution")

case ResumeExecutionRequest():
token_data = await get_paused_execution_token(
developer_id=x_developer_id, execution_id=execution_id
)
token_data = await get_paused_execution_token(execution_id=execution_id)
activity_id = token_data["metadata"].get("x-activity-id", None)
run_id = token_data["metadata"].get("x-run-id", None)
workflow_id = token_data["metadata"].get("x-workflow-id", None)
Expand Down
Loading

0 comments on commit 4e1c975

Please sign in to comment.