From 83f58aca92fc715cfbafc5f9f2f19f95cbf2da1e Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 14:45:28 +0300 Subject: [PATCH] feat: Add list tools query --- .../agents_api/queries/tools/list_tools.py | 92 ++++++++----------- 1 file changed, 37 insertions(+), 55 deletions(-) diff --git a/agents-api/agents_api/queries/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py index 727bf8028..59fb1eff5 100644 --- a/agents-api/agents_api/queries/tools/list_tools.py +++ b/agents-api/agents_api/queries/tools/list_tools.py @@ -1,32 +1,43 @@ from typing import Any, Literal, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ...autogen.openapi_model import Tool +from ...exceptions import InvalidSQLQuery from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +sql_query = sqlvalidator.parse(""" +SELECT * FROM tools +WHERE + developer_id = $1 AND + agent_id = $2 +ORDER BY + CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN s.created_at END DESC, + CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN s.created_at END ASC, + CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN s.updated_at END DESC, + CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN s.updated_at END ASC +LIMIT $3 OFFSET $4; +""") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +if not sql_query.is_valid(): + raise InvalidSQLQuery("get_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( Tool, transform=lambda d: { @@ -38,7 +49,7 @@ **d, }, ) -@cozo_query +@pg_query @beartype def list_tools( *, @@ -49,46 +60,17 @@ def list_tools( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> tuple[list[str], dict]: + developer_id = str(developer_id) agent_id = str(agent_id) - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - list_query = f""" - input[agent_id] <- [[to_uuid($agent_id)]] - - ?[ - agent_id, - id, - name, - type, - spec, - description, - updated_at, - created_at, - ] := input[agent_id], - *tools {{ - agent_id, - tool_id: id, - name, - type, - spec, - description, - updated_at, - created_at, - }} - - :limit $limit - :offset $offset - :sort {sort} - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - list_query, - ] - return ( - queries, - {"agent_id": agent_id, "limit": limit, "offset": offset}, + sql_query.format(), + [ + developer_id, + agent_id, + limit, + offset, + sort_by, + direction, + ], )