From c038e97d33a70aee54dce440068ac68f4457d2fe Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Thu, 26 Dec 2024 18:40:32 +0530 Subject: [PATCH] fix(agents-api): Fix failing tests Signed-off-by: Diwank Singh Tomer --- .../agents_api/common/protocol/tasks.py | 4 +- .../executions/get_temporal_workflow_data.py | 1 - .../executions/prepare_execution_input.py | 6 +- .../agents_api/queries/tasks/get_task.py | 6 +- .../tools/get_tool_args_from_metadata.py | 10 +-- agents-api/tests/fixtures.py | 2 +- agents-api/tests/test_task_routes.py | 63 ++++++++++--------- 7 files changed, 48 insertions(+), 44 deletions(-) diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 95a58ee69..07736e7d6 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -307,7 +307,9 @@ def spec_to_task_data(spec: dict) -> dict: workflows_dict = {workflow["name"]: workflow["steps"] for workflow in workflows} tools = spec.pop("tools", []) - tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools] + tools = [ + {tool["type"]: tool.pop("spec"), **tool} for tool in tools if tool is not None + ] return { "id": task_id, diff --git a/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py b/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py index e30e1fdce..624ff5abf 100644 --- a/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py +++ b/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py @@ -4,7 +4,6 @@ import asyncpg from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one from ..utils import ( partialclass, diff --git a/agents-api/agents_api/queries/executions/prepare_execution_input.py b/agents-api/agents_api/queries/executions/prepare_execution_input.py index f75ced5ab..1ddca0622 100644 --- a/agents-api/agents_api/queries/executions/prepare_execution_input.py +++ b/agents-api/agents_api/queries/executions/prepare_execution_input.py @@ -25,7 +25,7 @@ ) a ) AS agent, ( - SELECT jsonb_agg(r) AS tools FROM ( + SELECT COALESCE(jsonb_agg(r), '[]'::jsonb) AS tools FROM ( SELECT * FROM tools WHERE developer_id = $1 AND @@ -72,7 +72,9 @@ **d["agent"], }, "agent_tools": [ - {tool["type"]: tool.pop("spec"), **tool} for tool in d["tools"] + {tool["type"]: tool.pop("spec"), **tool} + for tool in d["tools"] + if tool is not None ], "arguments": d["execution"]["input"], "execution": { diff --git a/agents-api/agents_api/queries/tasks/get_task.py b/agents-api/agents_api/queries/tasks/get_task.py index 9786751c7..bb83e8d36 100644 --- a/agents-api/agents_api/queries/tasks/get_task.py +++ b/agents-api/agents_api/queries/tasks/get_task.py @@ -23,12 +23,12 @@ ) FILTER (WHERE w.name IS NOT NULL), '[]'::jsonb ) as workflows, - jsonb_agg(tl) as tools + COALESCE(jsonb_agg(tl), '[]'::jsonb) as tools FROM tasks t -INNER JOIN +LEFT JOIN workflows w ON t.developer_id = w.developer_id AND t.task_id = w.task_id AND t.version = w.version -INNER JOIN +LEFT JOIN tools tl ON t.developer_id = tl.developer_id AND t.task_id = tl.task_id WHERE t.developer_id = $1 AND t.task_id = $2 diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py index ace75bac5..6f38e4269 100644 --- a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -2,7 +2,6 @@ from uuid import UUID from beartype import beartype -from sqlglot import parse_one from ..utils import ( pg_query, @@ -10,7 +9,7 @@ ) # Define the raw SQL query for getting tool args from metadata -tools_args_for_task_query = parse_one(""" +tools_args_for_task_query = """ SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM ( SELECT CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' @@ -28,10 +27,11 @@ WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md FROM tasks WHERE task_id = $2 AND developer_id = $4 LIMIT 1 -) AS tasks_md""").sql(pretty=True) +) AS tasks_md""" # Define the raw SQL query for getting tool args from metadata for a session -tool_args_for_session_query = parse_one("""SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( +tool_args_for_session_query = """ +SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( SELECT CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' @@ -48,7 +48,7 @@ WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md FROM sessions WHERE session_id = $2 AND developer_id = $4 LIMIT 1 -) AS sessions_md""").sql(pretty=True) +) AS sessions_md""" # @rewrap_exceptions( diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 03286251b..9d781804e 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -329,7 +329,7 @@ async def test_execution_started( async def test_transition( dsn=pg_dsn, developer_id=test_developer_id, - execution=test_execution, + execution=test_execution_started, ): pool = await create_db_pool(dsn=dsn) transition = await create_execution_transition( diff --git a/agents-api/tests/test_task_routes.py b/agents-api/tests/test_task_routes.py index ee8395f84..0664847ad 100644 --- a/agents-api/tests/test_task_routes.py +++ b/agents-api/tests/test_task_routes.py @@ -3,15 +3,16 @@ from uuid_extensions import uuid7 from ward import test -from tests.fixtures import ( +from .fixtures import ( client, make_request, test_agent, - # test_execution, + test_execution, + test_transition, test_task, ) -from .fixtures import test_execution, test_transition +from .utils import patch_testing_temporal @test("route: unauthorized should fail") @@ -60,43 +61,43 @@ def _(make_request=make_request, agent=test_agent): assert response.status_code == 201 -# @test("route: create task execution") -# async def _(make_request=make_request, task=test_task): -# data = dict( -# input={}, -# metadata={}, -# ) +@test("route: create task execution") +async def _(make_request=make_request, task=test_task): + data = dict( + input={}, + metadata={}, + ) -# async with patch_testing_temporal(): -# response = make_request( -# method="POST", -# url=f"/tasks/{str(task.id)}/executions", -# json=data, -# ) + async with patch_testing_temporal(): + response = make_request( + method="POST", + url=f"/tasks/{str(task.id)}/executions", + json=data, + ) -# assert response.status_code == 201 + assert response.status_code == 201 -# @test("route: get execution not exists") -# def _(make_request=make_request): -# execution_id = str(uuid7()) +@test("route: get execution not exists") +def _(make_request=make_request): + execution_id = str(uuid7()) -# response = make_request( -# method="GET", -# url=f"/executions/{execution_id}", -# ) + response = make_request( + method="GET", + url=f"/executions/{execution_id}", + ) -# assert response.status_code == 404 + assert response.status_code == 404 -# @test("route: get execution exists") -# def _(make_request=make_request, execution=test_execution): -# response = make_request( -# method="GET", -# url=f"/executions/{str(execution.id)}", -# ) +@test("route: get execution exists") +def _(make_request=make_request, execution=test_execution): + response = make_request( + method="GET", + url=f"/executions/{str(execution.id)}", + ) -# assert response.status_code == 200 + assert response.status_code == 200 @test("route: get task not exists")