Skip to content

Commit

Permalink
fix(agents-api): Fix failing tests
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Dec 26, 2024
1 parent 3b188f3 commit c038e97
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 44 deletions.
4 changes: 3 additions & 1 deletion agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,9 @@ def spec_to_task_data(spec: dict) -> dict:
workflows_dict = {workflow["name"]: workflow["steps"] for workflow in workflows}

tools = spec.pop("tools", [])
tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools]
tools = [
{tool["type"]: tool.pop("spec"), **tool} for tool in tools if tool is not None
]

return {
"id": task_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ..utils import (
partialclass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
) a
) AS agent,
(
SELECT jsonb_agg(r) AS tools FROM (
SELECT COALESCE(jsonb_agg(r), '[]'::jsonb) AS tools FROM (
SELECT * FROM tools
WHERE
developer_id = $1 AND
Expand Down Expand Up @@ -72,7 +72,9 @@
**d["agent"],
},
"agent_tools": [
{tool["type"]: tool.pop("spec"), **tool} for tool in d["tools"]
{tool["type"]: tool.pop("spec"), **tool}
for tool in d["tools"]
if tool is not None
],
"arguments": d["execution"]["input"],
"execution": {
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/queries/tasks/get_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
) FILTER (WHERE w.name IS NOT NULL),
'[]'::jsonb
) as workflows,
jsonb_agg(tl) as tools
COALESCE(jsonb_agg(tl), '[]'::jsonb) as tools
FROM
tasks t
INNER JOIN
LEFT JOIN
workflows w ON t.developer_id = w.developer_id AND t.task_id = w.task_id AND t.version = w.version
INNER JOIN
LEFT JOIN
tools tl ON t.developer_id = tl.developer_id AND t.task_id = tl.task_id
WHERE
t.developer_id = $1 AND t.task_id = $2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from uuid import UUID

from beartype import beartype
from sqlglot import parse_one

from ..utils import (
pg_query,
wrap_in_class,
)

# Define the raw SQL query for getting tool args from metadata
tools_args_for_task_query = parse_one("""
tools_args_for_task_query = """
SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM (
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
Expand All @@ -28,10 +27,11 @@
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md
FROM tasks
WHERE task_id = $2 AND developer_id = $4 LIMIT 1
) AS tasks_md""").sql(pretty=True)
) AS tasks_md"""

# Define the raw SQL query for getting tool args from metadata for a session
tool_args_for_session_query = parse_one("""SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM (
tool_args_for_session_query = """
SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM (
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args'
Expand All @@ -48,7 +48,7 @@
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md
FROM sessions
WHERE session_id = $2 AND developer_id = $4 LIMIT 1
) AS sessions_md""").sql(pretty=True)
) AS sessions_md"""


# @rewrap_exceptions(
Expand Down
2 changes: 1 addition & 1 deletion agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ async def test_execution_started(
async def test_transition(
dsn=pg_dsn,
developer_id=test_developer_id,
execution=test_execution,
execution=test_execution_started,
):
pool = await create_db_pool(dsn=dsn)
transition = await create_execution_transition(
Expand Down
63 changes: 32 additions & 31 deletions agents-api/tests/test_task_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
from uuid_extensions import uuid7
from ward import test

from tests.fixtures import (
from .fixtures import (
client,
make_request,
test_agent,
# test_execution,
test_execution,
test_transition,
test_task,
)

from .fixtures import test_execution, test_transition
from .utils import patch_testing_temporal


@test("route: unauthorized should fail")
Expand Down Expand Up @@ -60,43 +61,43 @@ def _(make_request=make_request, agent=test_agent):
assert response.status_code == 201


# @test("route: create task execution")
# async def _(make_request=make_request, task=test_task):
# data = dict(
# input={},
# metadata={},
# )
@test("route: create task execution")
async def _(make_request=make_request, task=test_task):
data = dict(
input={},
metadata={},
)

# async with patch_testing_temporal():
# response = make_request(
# method="POST",
# url=f"/tasks/{str(task.id)}/executions",
# json=data,
# )
async with patch_testing_temporal():
response = make_request(
method="POST",
url=f"/tasks/{str(task.id)}/executions",
json=data,
)

# assert response.status_code == 201
assert response.status_code == 201


# @test("route: get execution not exists")
# def _(make_request=make_request):
# execution_id = str(uuid7())
@test("route: get execution not exists")
def _(make_request=make_request):
execution_id = str(uuid7())

# response = make_request(
# method="GET",
# url=f"/executions/{execution_id}",
# )
response = make_request(
method="GET",
url=f"/executions/{execution_id}",
)

# assert response.status_code == 404
assert response.status_code == 404


# @test("route: get execution exists")
# def _(make_request=make_request, execution=test_execution):
# response = make_request(
# method="GET",
# url=f"/executions/{str(execution.id)}",
# )
@test("route: get execution exists")
def _(make_request=make_request, execution=test_execution):
response = make_request(
method="GET",
url=f"/executions/{str(execution.id)}",
)

# assert response.status_code == 200
assert response.status_code == 200


@test("route: get task not exists")
Expand Down

0 comments on commit c038e97

Please sign in to comment.