Skip to content

Commit

Permalink
Merge pull request #1001 from julep-ai/x/misc-test-fixes
Browse files Browse the repository at this point in the history
chore: misc test and queries fixes
  • Loading branch information
creatorrr authored Dec 31, 2024
2 parents 33ac790 + 213807a commit 81841e4
Show file tree
Hide file tree
Showing 14 changed files with 164 additions and 70 deletions.
5 changes: 5 additions & 0 deletions agents-api/agents_api/common/utils/db_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def get_operation_message(base_msg: str) -> str:
status_code=404,
detail=get_operation_message(f"Required key not found for {resource_name}"),
),
AssertionError: partialclass(
HTTPException,
status_code=404,
detail=get_operation_message(f"No {resource_name} found"),
),
# Pydantic validation errors
pydantic.ValidationError: lambda e: partialclass(
HTTPException,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST
LIMIT $2 OFFSET $3;
"""
# Query to get a single transition
get_execution_transition_query = """
SELECT * FROM transitions
WHERE
execution_id = $1
AND transition_id = $2;
"""


def _transform(d):
Expand Down Expand Up @@ -53,11 +60,12 @@ def _transform(d):
Transition,
transform=_transform,
)
@pg_query
@pg_query(debug=True)
@beartype
async def list_execution_transitions(
*,
execution_id: UUID,
transition_id: UUID | None = None,
limit: int = 100,
offset: int = 0,
sort_by: Literal["created_at"] = "created_at",
Expand All @@ -76,6 +84,14 @@ async def list_execution_transitions(
Returns:
tuple[str, list]: SQL query and parameters for listing execution transitions.
"""
if transition_id is not None:
return (
get_execution_transition_query,
[
str(execution_id),
str(transition_id),
],
)
return (
list_execution_transitions_query,
[
Expand Down
25 changes: 22 additions & 3 deletions agents-api/agents_api/queries/tasks/create_or_update_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,26 @@
RETURNING *;
"""

# Define the raw SQL query for creating or updating a task
task_query = """
WITH current_version AS (
SELECT COALESCE(
(SELECT MAX("version")
FROM tasks
WHERE developer_id = $1
AND task_id = $4),
0
) + 1 as next_version,
COALESCE(
(SELECT canonical_name
FROM tasks
WHERE developer_id = $1 AND task_id = $4
ORDER BY version DESC
LIMIT 1),
$2
) as effective_canonical_name
FROM (SELECT 1) as dummy
)
INSERT INTO tasks (
"version",
developer_id,
Expand All @@ -53,9 +72,9 @@
metadata
)
SELECT
next_version, -- version
next_version, -- version
$1, -- developer_id
effective_canonical_name, -- canonical_name
effective_canonical_name, -- canonical_name
$3, -- agent_id
$4, -- task_id
$5, -- name
Expand Down Expand Up @@ -99,7 +118,7 @@
$4, -- step_idx
$5, -- step_type
$6 -- step_definition
FROM version
FROM version;
"""


Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/queries/tasks/update_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
name, -- $6
description, -- $7
inherit_tools, -- $8
input_schema, -- $9
input_schema -- $9
)
SELECT
current_version + 1, -- version
Expand Down Expand Up @@ -72,7 +72,7 @@
$4, -- step_idx
$5, -- step_type
$6 -- step_definition
FROM version
FROM version;
"""


Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/healthz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .check_health import check_health as check_health
from .router import router as router
3 changes: 3 additions & 0 deletions agents-api/agents_api/routers/healthz/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from fastapi import APIRouter

router: APIRouter = APIRouter()
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/jobs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .routers import router # noqa: F401
from .routers import router as router
3 changes: 1 addition & 2 deletions agents-api/agents_api/routers/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .list_execution_transitions import list_execution_transitions
from .list_task_executions import list_task_executions
from .list_tasks import list_tasks

# from .patch_execution import patch_execution
from .router import router
from .stream_transitions_events import stream_transitions_events
from .update_execution import update_execution
9 changes: 0 additions & 9 deletions agents-api/agents_api/routers/tasks/create_task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,6 @@ async def create_task_execution(
detail="Invalid request arguments schema",
)

# except QueryException as e:
# if e.code == "transact::assertion_failure":
# raise HTTPException(
# status_code=status.HTTP_404_NOT_FOUND, detail="Task not found"
# )

# raise

# get developer data
developer: Developer = await get_developer(developer_id=x_developer_id)

Expand All @@ -159,7 +151,6 @@ async def create_task_execution(

background_tasks.add_task(
create_temporal_lookup,
#
execution_id=execution.id,
workflow_handle=handle,
)
Expand Down
39 changes: 20 additions & 19 deletions agents-api/agents_api/routers/tasks/list_execution_transitions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Literal
from uuid import UUID

from fastapi import HTTPException, status

from ...autogen.openapi_model import (
ListResponse,
Transition,
Expand Down Expand Up @@ -30,22 +32,21 @@ async def list_execution_transitions(
return ListResponse[Transition](items=transitions)


# TODO: Do we need this?
# @router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"])
# async def get_execution_transition(
# execution_id: UUID,
# transition_id: UUID,
# ) -> Transition:
# try:
# res = [
# row.to_dict()
# for _, row in get_execution_transition_query(
# execution_id, transition_id
# ).iterrows()
# ][0]
# return Transition(**res)
# except (IndexError, KeyError):
# raise HTTPException(
# status_code=status.HTTP_404_NOT_FOUND,
# detail="Transition not found",
# )
@router.get("/executions/{execution_id}/transitions/{transition_id}", tags=["tasks"])
async def get_execution_transition(
execution_id: UUID,
transition_id: UUID,
) -> Transition:
try:
transitions = await list_execution_transitions_query(
execution_id=execution_id,
transition_id=transition_id,
)
if not transitions:
raise IndexError
return transitions[0]
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Transition not found",
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
agents,
docs,
files,
healthz,
internal,
jobs,
sessions,
Expand Down Expand Up @@ -151,6 +152,7 @@ def register_exceptions(app: FastAPI) -> None:
app.include_router(docs.router, dependencies=[Depends(get_api_key)])
app.include_router(tasks.router, dependencies=[Depends(get_api_key)])
app.include_router(internal.router)
app.include_router(healthz.router)

# TODO: CORS should be enabled only for JWT auth
#
Expand Down
2 changes: 0 additions & 2 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,13 @@ async def test_execution_started(
# Start the execution
await create_execution_transition(
developer_id=developer_id,
# task_id=task.id,
execution_id=execution.id,
data=CreateTransitionRequest(
type="init",
output={},
current={"workflow": "main", "step": 0},
next={"workflow": "main", "step": 0},
),
# update_execution_status=True,
connection_pool=pool,
)
yield execution
Expand Down
20 changes: 14 additions & 6 deletions agents-api/tests/test_docs_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def _(make_request=make_request, agent=test_agent):
async with patch_testing_temporal():
data = {
"title": "Test Agent Doc",
"content": ["This is a test agent document."],
"content": "This is a test agent document.",
}

response = make_request(
Expand All @@ -63,6 +63,16 @@ async def _(make_request=make_request, agent=test_agent):
)
doc_id = response.json()["id"]

response = make_request(
method="GET",
url=f"/docs/{doc_id}",
)

assert response.status_code == 200
assert response.json()["id"] == doc_id
assert response.json()["title"] == "Test Agent Doc"
assert response.json()["content"] == "This is a test agent document."

response = make_request(
method="DELETE",
url=f"/agents/{agent.id}/docs/{doc_id}",
Expand Down Expand Up @@ -163,9 +173,7 @@ def _(make_request=make_request, agent=test_agent):
assert isinstance(docs, list)


# TODO: Fix this test. It fails sometimes and sometimes not.


@skip("Fails due to FTS not working in Test Container")
@test("route: search agent docs")
async def _(make_request=make_request, agent=test_agent, doc=test_doc):
await asyncio.sleep(0.5)
Expand All @@ -188,8 +196,7 @@ async def _(make_request=make_request, agent=test_agent, doc=test_doc):
assert len(docs) >= 1


# FIXME: This test is failing because the search is not returning the expected results
@skip("Fails randomly on CI")
@skip("Fails due to FTS not working in Test Container")
@test("route: search user docs")
async def _(make_request=make_request, user=test_user, doc=test_user_doc):
await asyncio.sleep(0.5)
Expand All @@ -213,6 +220,7 @@ async def _(make_request=make_request, user=test_user, doc=test_user_doc):
assert len(docs) >= 1


@skip("Fails due to Vectorizer and FTS not working in Test Container")
@test("route: search agent docs hybrid with mmr")
async def _(make_request=make_request, agent=test_agent, doc=test_doc):
await asyncio.sleep(0.5)
Expand Down
Loading

0 comments on commit 81841e4

Please sign in to comment.