Skip to content

Commit

Permalink
feat: Add list tools query
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 20, 2024
1 parent e7d3079 commit 83f58ac
Showing 1 changed file with 37 additions and 55 deletions.
92 changes: 37 additions & 55 deletions agents-api/agents_api/queries/tools/list_tools.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,43 @@
from typing import Any, Literal, TypeVar
from uuid import UUID

import sqlvalidator
from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import Tool
from ...exceptions import InvalidSQLQuery
from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
pg_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

sql_query = sqlvalidator.parse("""
SELECT * FROM tools
WHERE
developer_id = $1 AND
agent_id = $2
ORDER BY
CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN s.created_at END DESC,
CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN s.created_at END ASC,
CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN s.updated_at END DESC,
CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN s.updated_at END ASC
LIMIT $3 OFFSET $4;
""")

@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
)
if not sql_query.is_valid():
raise InvalidSQLQuery("get_tool")


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=400),
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@wrap_in_class(
Tool,
transform=lambda d: {
Expand All @@ -38,7 +49,7 @@
**d,
},
)
@cozo_query
@pg_query
@beartype
def list_tools(
*,
Expand All @@ -49,46 +60,17 @@ def list_tools(
sort_by: Literal["created_at", "updated_at"] = "created_at",
direction: Literal["asc", "desc"] = "desc",
) -> tuple[list[str], dict]:
developer_id = str(developer_id)
agent_id = str(agent_id)

sort = f"{'-' if direction == 'desc' else ''}{sort_by}"

list_query = f"""
input[agent_id] <- [[to_uuid($agent_id)]]
?[
agent_id,
id,
name,
type,
spec,
description,
updated_at,
created_at,
] := input[agent_id],
*tools {{
agent_id,
tool_id: id,
name,
type,
spec,
description,
updated_at,
created_at,
}}
:limit $limit
:offset $offset
:sort {sort}
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id),
list_query,
]

return (
queries,
{"agent_id": agent_id, "limit": limit, "offset": offset},
sql_query.format(),
[
developer_id,
agent_id,
limit,
offset,
sort_by,
direction,
],
)

0 comments on commit 83f58ac

Please sign in to comment.