From 9d0068eb75c2923caf7d1e5034dca8f042718f34 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Wed, 18 Dec 2024 15:39:35 +0300 Subject: [PATCH 01/21] chore: Move ti queries directory --- .../models/chat/get_cached_response.py | 15 --------------- .../models/chat/set_cached_response.py | 19 ------------------- .../{models => queries}/chat/__init__.py | 2 -- .../chat/gather_messages.py | 0 .../chat/prepare_chat_context.py | 15 +++++++-------- 5 files changed, 7 insertions(+), 44 deletions(-) delete mode 100644 agents-api/agents_api/models/chat/get_cached_response.py delete mode 100644 agents-api/agents_api/models/chat/set_cached_response.py rename agents-api/agents_api/{models => queries}/chat/__init__.py (92%) rename agents-api/agents_api/{models => queries}/chat/gather_messages.py (100%) rename agents-api/agents_api/{models => queries}/chat/prepare_chat_context.py (92%) diff --git a/agents-api/agents_api/models/chat/get_cached_response.py b/agents-api/agents_api/models/chat/get_cached_response.py deleted file mode 100644 index 368c88567..000000000 --- a/agents-api/agents_api/models/chat/get_cached_response.py +++ /dev/null @@ -1,15 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def get_cached_response(key: str) -> tuple[str, dict]: - query = """ - input[key] <- [[$key]] - ?[key, value] := input[key], *session_cache{key, value} - :limit 1 - """ - - return (query, {"key": key}) diff --git a/agents-api/agents_api/models/chat/set_cached_response.py b/agents-api/agents_api/models/chat/set_cached_response.py deleted file mode 100644 index 8625f3f1b..000000000 --- a/agents-api/agents_api/models/chat/set_cached_response.py +++ /dev/null @@ -1,19 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def set_cached_response(key: str, value: dict) -> tuple[str, dict]: - query = """ - ?[key, value] <- [[$key, $value]] - - :insert session_cache { - key => value - } - - :returning - """ - - return (query, {"key": key, "value": value}) diff --git a/agents-api/agents_api/models/chat/__init__.py b/agents-api/agents_api/queries/chat/__init__.py similarity index 92% rename from agents-api/agents_api/models/chat/__init__.py rename to agents-api/agents_api/queries/chat/__init__.py index 428b72572..2c05b4f8b 100644 --- a/agents-api/agents_api/models/chat/__init__.py +++ b/agents-api/agents_api/queries/chat/__init__.py @@ -17,6 +17,4 @@ # ruff: noqa: F401, F403, F405 from .gather_messages import gather_messages -from .get_cached_response import get_cached_response from .prepare_chat_context import prepare_chat_context -from .set_cached_response import set_cached_response diff --git a/agents-api/agents_api/models/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py similarity index 100% rename from agents-api/agents_api/models/chat/gather_messages.py rename to agents-api/agents_api/queries/chat/gather_messages.py diff --git a/agents-api/agents_api/models/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py similarity index 92% rename from agents-api/agents_api/models/chat/prepare_chat_context.py rename to agents-api/agents_api/queries/chat/prepare_chat_context.py index f77686d7a..4731618f8 100644 --- a/agents-api/agents_api/models/chat/prepare_chat_context.py +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -3,7 +3,6 @@ from beartype import beartype from fastapi import HTTPException -from pycozo.client import QueryException from pydantic import ValidationError from ...common.protocol.sessions import ChatContext, make_session @@ -22,13 +21,13 @@ T = TypeVar("T") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +# TODO: implement this part +# @rewrap_exceptions( +# { +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( ChatContext, one=True, From 780100b1f2a6ce87a4918b6b45c1a03ee9d4f10b Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 19 Dec 2024 15:34:37 +0300 Subject: [PATCH 02/21] feat: Add prepare chat context query --- .../queries/chat/gather_messages.py | 12 +- .../queries/chat/prepare_chat_context.py | 225 ++++++++++-------- 2 files changed, 129 insertions(+), 108 deletions(-) diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index 28dc6607f..34a7c564f 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -3,18 +3,17 @@ from beartype import beartype from fastapi import HTTPException -from pycozo.client import QueryException from pydantic import ValidationError from ...autogen.openapi_model import ChatInput, DocReference, History from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext -from ..docs.search_docs_by_embedding import search_docs_by_embedding -from ..docs.search_docs_by_text import search_docs_by_text -from ..docs.search_docs_hybrid import search_docs_hybrid -from ..entry.get_history import get_history -from ..session.get_session import get_session +# from ..docs.search_docs_by_embedding import search_docs_by_embedding +# from ..docs.search_docs_by_text import search_docs_by_text +# from ..docs.search_docs_hybrid import search_docs_hybrid +# from ..entry.get_history import get_history +from ..sessions.get_session import get_session from ..utils import ( partialclass, rewrap_exceptions, @@ -25,7 +24,6 @@ @rewrap_exceptions( { - QueryException: partialclass(HTTPException, status_code=400), ValidationError: partialclass(HTTPException, status_code=400), TypeError: partialclass(HTTPException, status_code=400), } diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py index 4731618f8..23926ea4c 100644 --- a/agents-api/agents_api/queries/chat/prepare_chat_context.py +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -2,18 +2,10 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException -from pydantic import ValidationError from ...common.protocol.sessions import ChatContext, make_session -from ..session.prepare_session_data import prepare_session_data from ..utils import ( - cozo_query, - fix_uuid_if_present, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) @@ -21,17 +13,107 @@ T = TypeVar("T") -# TODO: implement this part -# @rewrap_exceptions( -# { -# ValidationError: partialclass(HTTPException, status_code=400), -# TypeError: partialclass(HTTPException, status_code=400), -# } -# ) -@wrap_in_class( - ChatContext, - one=True, - transform=lambda d: { +query = """ +SELECT * FROM +( + SELECT jsonb_agg(u) AS users FROM ( + SELECT + session_lookup.participant_id, + users.user_id AS id, + users.developer_id, + users.name, + users.about, + users.created_at, + users.updated_at, + users.metadata + FROM session_lookup + INNER JOIN users ON session_lookup.participant_id = users.user_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'user' + ) u +) AS users, +( + SELECT jsonb_agg(a) AS agents FROM ( + SELECT + session_lookup.participant_id, + agents.agent_id AS id, + agents.developer_id, + agents.canonical_name, + agents.name, + agents.about, + agents.instructions, + agents.model, + agents.created_at, + agents.updated_at, + agents.metadata, + agents.default_settings + FROM session_lookup + INNER JOIN agents ON session_lookup.participant_id = agents.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) a +) AS agents, +( + SELECT to_jsonb(s) AS session FROM ( + SELECT + sessions.session_id AS id, + sessions.developer_id, + sessions.situation, + sessions.system_template, + sessions.created_at, + sessions.metadata, + sessions.render_templates, + sessions.token_budget, + sessions.context_overflow, + sessions.forward_tool_calls, + sessions.recall_options + FROM sessions + WHERE + developer_id = $1 AND + session_id = $2 + LIMIT 1 + ) s +) AS session, +( + SELECT jsonb_agg(r) AS toolsets FROM ( + SELECT + session_lookup.participant_id, + tools.tool_id as id, + tools.developer_id, + tools.agent_id, + tools.task_id, + tools.task_version, + tools.type, + tools.name, + tools.description, + tools.spec, + tools.updated_at, + tools.created_at + FROM session_lookup + INNER JOIN tools ON session_lookup.participant_id = tools.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) r +) AS toolsets +""" + + +def _transform(d): + toolsets = {} + for tool in d["toolsets"]: + agent_id = tool["agent_id"] + if agent_id in toolsets: + toolsets[agent_id].append(tool) + else: + toolsets[agent_id] = [tool] + + return { **d, "session": make_session( agents=[a["id"] for a in d["agents"]], @@ -40,103 +122,44 @@ ), "toolsets": [ { - **ts, + "agent_id": agent_id, "tools": [ { tool["type"]: tool.pop("spec"), **tool, } - for tool in map(fix_uuid_if_present, ts["tools"]) + for tool in tools ], } - for ts in d["toolsets"] + for agent_id, tools in toolsets.items() ], - }, + } + + +# TODO: implement this part +# @rewrap_exceptions( +# { +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + ChatContext, + one=True, + transform=_transform, ) -@cozo_query +@pg_query @beartype -def prepare_chat_context( +async def prepare_chat_context( *, developer_id: UUID, session_id: UUID, -) -> tuple[list[str], dict]: +) -> tuple[list[str], list]: """ Executes a complex query to retrieve memory context based on session ID. """ - [*_, session_data_query], sd_vars = prepare_session_data.__wrapped__( - developer_id=developer_id, session_id=session_id - ) - - session_data_fields = ("session", "agents", "users") - - session_data_query += """ - :create _session_data_json { - agents: [Json], - users: [Json], - session: Json, - } - """ - - toolsets_query = """ - input[session_id] <- [[to_uuid($session_id)]] - - tools_by_agent[agent_id, collect(tool)] := - input[session_id], - *session_lookup{ - session_id, - participant_id: agent_id, - participant_type: "agent", - }, - - *tools { agent_id, tool_id, name, type, spec, description, updated_at, created_at }, - tool = { - "id": tool_id, - "name": name, - "type": type, - "spec": spec, - "description": description, - "updated_at": updated_at, - "created_at": created_at, - } - - agent_toolsets[collect(toolset)] := - tools_by_agent[agent_id, tools], - toolset = { - "agent_id": agent_id, - "tools": tools, - } - - ?[toolsets] := - agent_toolsets[toolsets] - - :create _toolsets_json { - toolsets: [Json], - } - """ - - combine_query = f""" - ?[{', '.join(session_data_fields)}, toolsets] := - *_session_data_json {{ {', '.join(session_data_fields)} }}, - *_toolsets_json {{ toolsets }} - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - session_data_query, - toolsets_query, - combine_query, - ] - return ( - queries, - { - "session_id": str(session_id), - **sd_vars, - }, + [query], + [developer_id, session_id], ) From c9fc7579c08b65c1203deec0a81deb5b5e6060ec Mon Sep 17 00:00:00 2001 From: whiterabbit1983 Date: Thu, 19 Dec 2024 12:38:09 +0000 Subject: [PATCH 03/21] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/chat/gather_messages.py | 1 + 1 file changed, 1 insertion(+) diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index 34a7c564f..4fd574368 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -9,6 +9,7 @@ from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext + # from ..docs.search_docs_by_embedding import search_docs_by_embedding # from ..docs.search_docs_by_text import search_docs_by_text # from ..docs.search_docs_hybrid import search_docs_hybrid From 1d2bd9a4342c0e7bb095dfcfa6087c182084afd2 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Thu, 19 Dec 2024 15:49:10 +0300 Subject: [PATCH 04/21] feat: Add SQL validation --- agents-api/agents_api/exceptions.py | 9 ++ .../queries/chat/prepare_chat_context.py | 90 ++++++++++--------- 2 files changed, 56 insertions(+), 43 deletions(-) diff --git a/agents-api/agents_api/exceptions.py b/agents-api/agents_api/exceptions.py index 615958a87..f6fcc4741 100644 --- a/agents-api/agents_api/exceptions.py +++ b/agents-api/agents_api/exceptions.py @@ -49,3 +49,12 @@ class FailedEncodingSentinel: """Sentinel object returned when failed to encode payload.""" payload_data: bytes + + +class QueriesBaseException(AgentsBaseException): + pass + + +class InvalidSQLQuery(QueriesBaseException): + def __init__(self, query_name: str): + super().__init__(f"invalid query: {query_name}") diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py index 23926ea4c..1d9bd52fb 100644 --- a/agents-api/agents_api/queries/chat/prepare_chat_context.py +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -1,9 +1,11 @@ from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype from ...common.protocol.sessions import ChatContext, make_session +from ...exceptions import InvalidSQLQuery from ..utils import ( pg_query, wrap_in_class, @@ -13,19 +15,19 @@ T = TypeVar("T") -query = """ -SELECT * FROM +sql_query = sqlvalidator.parse( + """SELECT * FROM ( SELECT jsonb_agg(u) AS users FROM ( SELECT session_lookup.participant_id, users.user_id AS id, - users.developer_id, - users.name, - users.about, - users.created_at, - users.updated_at, - users.metadata + users.developer_id, + users.name, + users.about, + users.created_at, + users.updated_at, + users.metadata FROM session_lookup INNER JOIN users ON session_lookup.participant_id = users.user_id WHERE @@ -39,16 +41,16 @@ SELECT session_lookup.participant_id, agents.agent_id AS id, - agents.developer_id, - agents.canonical_name, - agents.name, - agents.about, - agents.instructions, - agents.model, - agents.created_at, - agents.updated_at, - agents.metadata, - agents.default_settings + agents.developer_id, + agents.canonical_name, + agents.name, + agents.about, + agents.instructions, + agents.model, + agents.created_at, + agents.updated_at, + agents.metadata, + agents.default_settings FROM session_lookup INNER JOIN agents ON session_lookup.participant_id = agents.agent_id WHERE @@ -58,24 +60,24 @@ ) a ) AS agents, ( - SELECT to_jsonb(s) AS session FROM ( + SELECT to_jsonb(s) AS session FROM ( SELECT sessions.session_id AS id, - sessions.developer_id, - sessions.situation, - sessions.system_template, - sessions.created_at, - sessions.metadata, - sessions.render_templates, - sessions.token_budget, - sessions.context_overflow, - sessions.forward_tool_calls, - sessions.recall_options + sessions.developer_id, + sessions.situation, + sessions.system_template, + sessions.created_at, + sessions.metadata, + sessions.render_templates, + sessions.token_budget, + sessions.context_overflow, + sessions.forward_tool_calls, + sessions.recall_options FROM sessions WHERE developer_id = $1 AND session_id = $2 - LIMIT 1 + LIMIT 1 ) s ) AS session, ( @@ -83,16 +85,16 @@ SELECT session_lookup.participant_id, tools.tool_id as id, - tools.developer_id, - tools.agent_id, - tools.task_id, - tools.task_version, - tools.type, - tools.name, - tools.description, - tools.spec, - tools.updated_at, - tools.created_at + tools.developer_id, + tools.agent_id, + tools.task_id, + tools.task_version, + tools.type, + tools.name, + tools.description, + tools.spec, + tools.updated_at, + tools.created_at FROM session_lookup INNER JOIN tools ON session_lookup.participant_id = tools.agent_id WHERE @@ -100,8 +102,10 @@ session_id = $2 AND session_lookup.participant_type = 'agent' ) r -) AS toolsets -""" +) AS toolsets""" +) +if not sql_query.is_valid(): + raise InvalidSQLQuery("prepare_chat_context") def _transform(d): @@ -160,6 +164,6 @@ async def prepare_chat_context( """ return ( - [query], + [sql_query.format()], [developer_id, session_id], ) From 1bc8fe3439f38bede9615aa784dbfcd50d21b89c Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 11:49:21 +0300 Subject: [PATCH 05/21] chore: Import other required queries --- agents-api/agents_api/queries/chat/gather_messages.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index 4fd574368..94d5fe71a 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -10,10 +10,10 @@ from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext -# from ..docs.search_docs_by_embedding import search_docs_by_embedding -# from ..docs.search_docs_by_text import search_docs_by_text -# from ..docs.search_docs_hybrid import search_docs_hybrid -# from ..entry.get_history import get_history +from ..docs.search_docs_by_embedding import search_docs_by_embedding +from ..docs.search_docs_by_text import search_docs_by_text +from ..docs.search_docs_hybrid import search_docs_hybrid +from ..entries.get_history import get_history from ..sessions.get_session import get_session from ..utils import ( partialclass, From 2975407ab98d19bfde1ecac7ce2f574310efa7d1 Mon Sep 17 00:00:00 2001 From: whiterabbit1983 Date: Fri, 20 Dec 2024 08:50:13 +0000 Subject: [PATCH 06/21] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/chat/gather_messages.py | 1 - 1 file changed, 1 deletion(-) diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index 94d5fe71a..cbf3bf209 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -9,7 +9,6 @@ from ...clients import litellm from ...common.protocol.developers import Developer from ...common.protocol.sessions import ChatContext - from ..docs.search_docs_by_embedding import search_docs_by_embedding from ..docs.search_docs_by_text import search_docs_by_text from ..docs.search_docs_hybrid import search_docs_hybrid From ba3027b0f94fa7e05d75959ae0ebe74711846168 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 12:37:19 +0300 Subject: [PATCH 07/21] chore: Move queries to another folder --- .../{models => queries}/tools/__init__.py | 0 .../{models => queries}/tools/create_tools.py | 21 +++++++++---------- .../{models => queries}/tools/delete_tool.py | 0 .../{models => queries}/tools/get_tool.py | 0 .../tools/get_tool_args_from_metadata.py | 0 .../{models => queries}/tools/list_tools.py | 0 .../{models => queries}/tools/patch_tool.py | 0 .../{models => queries}/tools/update_tool.py | 0 8 files changed, 10 insertions(+), 11 deletions(-) rename agents-api/agents_api/{models => queries}/tools/__init__.py (100%) rename agents-api/agents_api/{models => queries}/tools/create_tools.py (89%) rename agents-api/agents_api/{models => queries}/tools/delete_tool.py (100%) rename agents-api/agents_api/{models => queries}/tools/get_tool.py (100%) rename agents-api/agents_api/{models => queries}/tools/get_tool_args_from_metadata.py (100%) rename agents-api/agents_api/{models => queries}/tools/list_tools.py (100%) rename agents-api/agents_api/{models => queries}/tools/patch_tool.py (100%) rename agents-api/agents_api/{models => queries}/tools/update_tool.py (100%) diff --git a/agents-api/agents_api/models/tools/__init__.py b/agents-api/agents_api/queries/tools/__init__.py similarity index 100% rename from agents-api/agents_api/models/tools/__init__.py rename to agents-api/agents_api/queries/tools/__init__.py diff --git a/agents-api/agents_api/models/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py similarity index 89% rename from agents-api/agents_api/models/tools/create_tools.py rename to agents-api/agents_api/queries/tools/create_tools.py index 578a1268d..0d2e0984c 100644 --- a/agents-api/agents_api/models/tools/create_tools.py +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -1,18 +1,18 @@ """This module contains functions for creating tools in the CozoDB database.""" +import sqlvalidator from typing import Any, TypeVar from uuid import UUID from beartype import beartype from fastapi import HTTPException -from pycozo.client import QueryException from pydantic import ValidationError from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateToolRequest, Tool from ...metrics.counters import increase_counter from ..utils import ( - cozo_query, + pg_query, partialclass, rewrap_exceptions, verify_developer_id_query, @@ -24,14 +24,13 @@ T = TypeVar("T") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - AssertionError: partialclass(HTTPException, status_code=400), - } -) +# @rewrap_exceptions( +# { +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# AssertionError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( Tool, transform=lambda d: { @@ -41,7 +40,7 @@ }, _kind="inserted", ) -@cozo_query +@pg_query @increase_counter("create_tools") @beartype def create_tools( diff --git a/agents-api/agents_api/models/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py similarity index 100% rename from agents-api/agents_api/models/tools/delete_tool.py rename to agents-api/agents_api/queries/tools/delete_tool.py diff --git a/agents-api/agents_api/models/tools/get_tool.py b/agents-api/agents_api/queries/tools/get_tool.py similarity index 100% rename from agents-api/agents_api/models/tools/get_tool.py rename to agents-api/agents_api/queries/tools/get_tool.py diff --git a/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py similarity index 100% rename from agents-api/agents_api/models/tools/get_tool_args_from_metadata.py rename to agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py diff --git a/agents-api/agents_api/models/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py similarity index 100% rename from agents-api/agents_api/models/tools/list_tools.py rename to agents-api/agents_api/queries/tools/list_tools.py diff --git a/agents-api/agents_api/models/tools/patch_tool.py b/agents-api/agents_api/queries/tools/patch_tool.py similarity index 100% rename from agents-api/agents_api/models/tools/patch_tool.py rename to agents-api/agents_api/queries/tools/patch_tool.py diff --git a/agents-api/agents_api/models/tools/update_tool.py b/agents-api/agents_api/queries/tools/update_tool.py similarity index 100% rename from agents-api/agents_api/models/tools/update_tool.py rename to agents-api/agents_api/queries/tools/update_tool.py From 5c060acea43bb44765ee6b716072296cad4e0a86 Mon Sep 17 00:00:00 2001 From: whiterabbit1983 Date: Fri, 20 Dec 2024 09:38:56 +0000 Subject: [PATCH 08/21] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/tools/create_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/agents-api/agents_api/queries/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py index 0d2e0984c..a54fa6973 100644 --- a/agents-api/agents_api/queries/tools/create_tools.py +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -1,9 +1,9 @@ """This module contains functions for creating tools in the CozoDB database.""" -import sqlvalidator from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype from fastapi import HTTPException from pydantic import ValidationError @@ -12,8 +12,8 @@ from ...autogen.openapi_model import CreateToolRequest, Tool from ...metrics.counters import increase_counter from ..utils import ( - pg_query, partialclass, + pg_query, rewrap_exceptions, verify_developer_id_query, verify_developer_owns_resource_query, From 47b7c7e7ee7794491941e8a5c479978b3cc79c5d Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 14:10:53 +0300 Subject: [PATCH 09/21] feat: Add create tools query --- .../agents_api/queries/tools/create_tools.py | 103 +++++++----------- 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/agents-api/agents_api/queries/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py index a54fa6973..d50e98e80 100644 --- a/agents-api/agents_api/queries/tools/create_tools.py +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -5,18 +5,14 @@ import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pydantic import ValidationError from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateToolRequest, Tool +from ...exceptions import InvalidSQLQuery from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + # rewrap_exceptions, wrap_in_class, ) @@ -24,6 +20,37 @@ T = TypeVar("T") +sql_query = sqlvalidator.parse( + """INSERT INTO tools +( + developer_id, + agent_id, + tool_id, + type, + name, + spec, + description +) +SELECT + $1, + $2, + $3, + $4, + $5, + $6, + $7 +WHERE NOT EXISTS ( + SELECT null FROM tools + WHERE (agent_id, name) = ($2, $5) +) +RETURNING * +""" +) + +if not sql_query.is_valid(): + raise InvalidSQLQuery("create_tools") + + # @rewrap_exceptions( # { # ValidationError: partialclass(HTTPException, status_code=400), @@ -48,8 +75,8 @@ def create_tools( developer_id: UUID, agent_id: UUID, data: list[CreateToolRequest], - ignore_existing: bool = False, -) -> tuple[list[str], dict]: + ignore_existing: bool = False, # TODO: what to do with this flag? +) -> tuple[list[str], list]: """ Constructs a datalog query for inserting tool records into the 'agent_functions' relation in the CozoDB. @@ -69,6 +96,7 @@ def create_tools( tools_data = [ [ + developer_id, str(agent_id), str(uuid7()), tool.type, @@ -79,57 +107,8 @@ def create_tools( for tool in data ] - ensure_tool_name_unique_query = """ - input[agent_id, tool_id, type, name, spec, description] <- $records - ?[tool_id] := - input[agent_id, _, type, name, _, _], - *tools{ - agent_id: to_uuid(agent_id), - tool_id, - type, - name, - spec, - description, - } - - :limit 1 - :assert none - """ - - # Datalog query for inserting new tool records into the 'tools' relation - create_query = """ - input[agent_id, tool_id, type, name, spec, description] <- $records - - # Do not add duplicate - ?[agent_id, tool_id, type, name, spec, description] := - input[agent_id, tool_id, type, name, spec, description], - not *tools{ - agent_id: to_uuid(agent_id), - type, - name, - } - - :insert tools { - agent_id, - tool_id, - type, - name, - spec, - description, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - create_query, - ] - - if not ignore_existing: - queries.insert( - -1, - ensure_tool_name_unique_query, - ) - - return (queries, {"records": tools_data}) + return ( + sql_query.format(), + tools_data, + "fetchmany", + ) From b774589abfd3f7569fe31f2d36393492a8b20dad Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 14:27:45 +0300 Subject: [PATCH 10/21] feat: Add delete tool query --- .../agents_api/queries/tools/delete_tool.py | 69 +++++++++---------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/agents-api/agents_api/queries/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py index c79cdfd29..59f561cf1 100644 --- a/agents-api/agents_api/queries/tools/delete_tool.py +++ b/agents-api/agents_api/queries/tools/delete_tool.py @@ -1,19 +1,14 @@ from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow +from ...exceptions import InvalidSQLQuery from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) @@ -21,20 +16,34 @@ T = TypeVar("T") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +sql_query = sqlvalidator.parse(""" +DELETE FROM + tools +WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 +RETURNING * +""") + +if not sql_query.is_valid(): + raise InvalidSQLQuery("delete_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( ResourceDeletedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "deleted_at": utcnow(), "jobs": [], **d}, _kind="deleted", ) -@cozo_query +@pg_query @beartype def delete_tool( *, @@ -42,27 +51,15 @@ def delete_tool( agent_id: UUID, tool_id: UUID, ) -> tuple[list[str], dict]: + developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) - delete_query = """ - # Delete function - ?[tool_id, agent_id] <- [[ - to_uuid($tool_id), - to_uuid($agent_id), - ]] - - :delete tools { - tool_id, + return ( + sql_query.format(), + [ + developer_id, agent_id, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - delete_query, - ] - - return (queries, {"tool_id": tool_id, "agent_id": agent_id}) + tool_id, + ], + ) From b2806ac80c2cfb99cef9022b8f957797150c38d5 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 14:32:39 +0300 Subject: [PATCH 11/21] feat: Add get tool query --- .../agents_api/queries/tools/get_tool.py | 76 ++++++++----------- 1 file changed, 30 insertions(+), 46 deletions(-) diff --git a/agents-api/agents_api/queries/tools/get_tool.py b/agents-api/agents_api/queries/tools/get_tool.py index 465fd2efe..3662725b8 100644 --- a/agents-api/agents_api/queries/tools/get_tool.py +++ b/agents-api/agents_api/queries/tools/get_tool.py @@ -1,32 +1,39 @@ from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ...autogen.openapi_model import Tool +from ...exceptions import InvalidSQLQuery from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +sql_query = sqlvalidator.parse(""" +SELECT * FROM tools +WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 +LIMIT 1 +""") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +if not sql_query.is_valid(): + raise InvalidSQLQuery("get_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( Tool, transform=lambda d: { @@ -36,7 +43,7 @@ }, one=True, ) -@cozo_query +@pg_query @beartype def get_tool( *, @@ -44,38 +51,15 @@ def get_tool( agent_id: UUID, tool_id: UUID, ) -> tuple[list[str], dict]: + developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) - get_query = """ - input[agent_id, tool_id] <- [[to_uuid($agent_id), to_uuid($tool_id)]] - - ?[ + return ( + sql_query.format(), + [ + developer_id, agent_id, tool_id, - type, - name, - spec, - updated_at, - created_at, - ] := input[agent_id, tool_id], - *tools { - agent_id, - tool_id, - name, - type, - spec, - updated_at, - created_at, - } - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - get_query, - ] - - return (queries, {"agent_id": agent_id, "tool_id": tool_id}) + ], + ) From cf184e0e5f2d1487f22e89561474ca976b85e80c Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 14:45:28 +0300 Subject: [PATCH 12/21] feat: Add list tools query --- .../agents_api/queries/tools/list_tools.py | 92 ++++++++----------- 1 file changed, 37 insertions(+), 55 deletions(-) diff --git a/agents-api/agents_api/queries/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py index 727bf8028..59fb1eff5 100644 --- a/agents-api/agents_api/queries/tools/list_tools.py +++ b/agents-api/agents_api/queries/tools/list_tools.py @@ -1,32 +1,43 @@ from typing import Any, Literal, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ...autogen.openapi_model import Tool +from ...exceptions import InvalidSQLQuery from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +sql_query = sqlvalidator.parse(""" +SELECT * FROM tools +WHERE + developer_id = $1 AND + agent_id = $2 +ORDER BY + CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN s.created_at END DESC, + CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN s.created_at END ASC, + CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN s.updated_at END DESC, + CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN s.updated_at END ASC +LIMIT $3 OFFSET $4; +""") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +if not sql_query.is_valid(): + raise InvalidSQLQuery("get_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( Tool, transform=lambda d: { @@ -38,7 +49,7 @@ **d, }, ) -@cozo_query +@pg_query @beartype def list_tools( *, @@ -49,46 +60,17 @@ def list_tools( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> tuple[list[str], dict]: + developer_id = str(developer_id) agent_id = str(agent_id) - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - list_query = f""" - input[agent_id] <- [[to_uuid($agent_id)]] - - ?[ - agent_id, - id, - name, - type, - spec, - description, - updated_at, - created_at, - ] := input[agent_id], - *tools {{ - agent_id, - tool_id: id, - name, - type, - spec, - description, - updated_at, - created_at, - }} - - :limit $limit - :offset $offset - :sort {sort} - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - list_query, - ] - return ( - queries, - {"agent_id": agent_id, "limit": limit, "offset": offset}, + sql_query.format(), + [ + developer_id, + agent_id, + limit, + offset, + sort_by, + direction, + ], ) From 53b65a19fea18a285f62d6758693ed44653de74c Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 15:21:09 +0300 Subject: [PATCH 13/21] feat: Add patch tool query --- .../agents_api/queries/tools/patch_tool.py | 94 +++++++++---------- 1 file changed, 43 insertions(+), 51 deletions(-) diff --git a/agents-api/agents_api/queries/tools/patch_tool.py b/agents-api/agents_api/queries/tools/patch_tool.py index bc49b8121..aa663dec0 100644 --- a/agents-api/agents_api/queries/tools/patch_tool.py +++ b/agents-api/agents_api/queries/tools/patch_tool.py @@ -1,20 +1,14 @@ from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ...autogen.openapi_model import PatchToolRequest, ResourceUpdatedResponse -from ...common.utils.cozo import cozo_process_mutate_data +from ...exceptions import InvalidSQLQuery from ...metrics.counters import increase_counter from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) @@ -22,25 +16,46 @@ T = TypeVar("T") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } +sql_query = sqlvalidator.parse(""" +WITH updated_tools AS ( + UPDATE tools + SET + type = COALESCE($4, type), + name = COALESCE($5, name), + description = COALESCE($6, description), + spec = COALESCE($7, spec) + WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 + RETURNING * ) +SELECT * FROM updated_tools; +""") + +if not sql_query.is_valid(): + raise InvalidSQLQuery("patch_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, _kind="inserted", ) -@cozo_query +@pg_query @increase_counter("patch_tool") @beartype def patch_tool( *, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest -) -> tuple[list[str], dict]: +) -> tuple[list[str], list]: """ Execute the datalog query and return the results as a DataFrame Updates the tool information for a given agent and tool ID in the 'cozodb' database. @@ -54,6 +69,7 @@ def patch_tool( ResourceUpdatedResponse: The updated tool data. """ + developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) @@ -78,39 +94,15 @@ def patch_tool( if tool_spec: del patch_data[tool_type] - tool_cols, tool_vals = cozo_process_mutate_data( - { - **patch_data, - "agent_id": agent_id, - "tool_id": tool_id, - } - ) - - # Construct the datalog query for updating the tool information - patch_query = f""" - input[{tool_cols}] <- $input - - ?[{tool_cols}, spec, updated_at] := - *tools {{ - agent_id: to_uuid($agent_id), - tool_id: to_uuid($tool_id), - spec: old_spec, - }}, - input[{tool_cols}], - spec = concat(old_spec, $spec), - updated_at = now() - - :update tools {{ {tool_cols}, spec, updated_at }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - patch_query, - ] - return ( - queries, - dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id), + sql_query.format(), + [ + developer_id, + agent_id, + tool_id, + tool_type, + data.name, + data.description, + tool_spec, + ], ) From 3299d54a16fb0485394f1a4d3658e6db4a768d73 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 15:21:21 +0300 Subject: [PATCH 14/21] fix: Fix return types --- agents-api/agents_api/queries/tools/delete_tool.py | 2 +- agents-api/agents_api/queries/tools/get_tool.py | 2 +- agents-api/agents_api/queries/tools/list_tools.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/agents-api/agents_api/queries/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py index 59f561cf1..17535e1e4 100644 --- a/agents-api/agents_api/queries/tools/delete_tool.py +++ b/agents-api/agents_api/queries/tools/delete_tool.py @@ -50,7 +50,7 @@ def delete_tool( developer_id: UUID, agent_id: UUID, tool_id: UUID, -) -> tuple[list[str], dict]: +) -> tuple[list[str], list]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) diff --git a/agents-api/agents_api/queries/tools/get_tool.py b/agents-api/agents_api/queries/tools/get_tool.py index 3662725b8..af63be0c9 100644 --- a/agents-api/agents_api/queries/tools/get_tool.py +++ b/agents-api/agents_api/queries/tools/get_tool.py @@ -50,7 +50,7 @@ def get_tool( developer_id: UUID, agent_id: UUID, tool_id: UUID, -) -> tuple[list[str], dict]: +) -> tuple[list[str], list]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) diff --git a/agents-api/agents_api/queries/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py index 59fb1eff5..3dac84875 100644 --- a/agents-api/agents_api/queries/tools/list_tools.py +++ b/agents-api/agents_api/queries/tools/list_tools.py @@ -28,7 +28,7 @@ """) if not sql_query.is_valid(): - raise InvalidSQLQuery("get_tool") + raise InvalidSQLQuery("list_tools") # @rewrap_exceptions( @@ -59,7 +59,7 @@ def list_tools( offset: int = 0, sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", -) -> tuple[list[str], dict]: +) -> tuple[list[str], list]: developer_id = str(developer_id) agent_id = str(agent_id) From 5e94d332119c88dca9c7dbba5cf601818958e4d5 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 20 Dec 2024 15:31:48 +0300 Subject: [PATCH 15/21] feat: Add update tool query --- .../agents_api/queries/tools/update_tool.py | 93 ++++++++----------- 1 file changed, 41 insertions(+), 52 deletions(-) diff --git a/agents-api/agents_api/queries/tools/update_tool.py b/agents-api/agents_api/queries/tools/update_tool.py index ef700a5f6..356e28bbf 100644 --- a/agents-api/agents_api/queries/tools/update_tool.py +++ b/agents-api/agents_api/queries/tools/update_tool.py @@ -1,44 +1,55 @@ from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ...autogen.openapi_model import ( ResourceUpdatedResponse, UpdateToolRequest, ) -from ...common.utils.cozo import cozo_process_mutate_data +from ...exceptions import InvalidSQLQuery from ...metrics.counters import increase_counter from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +sql_query = sqlvalidator.parse(""" +UPDATE tools +SET + type = $4, + name = $5, + description = $6, + spec = $7 +WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 +RETURNING *; +""") + +if not sql_query.is_valid(): + raise InvalidSQLQuery("update_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, _kind="inserted", ) -@cozo_query +@pg_query @increase_counter("update_tool") @beartype def update_tool( @@ -48,7 +59,8 @@ def update_tool( tool_id: UUID, data: UpdateToolRequest, **kwargs, -) -> tuple[list[str], dict]: +) -> tuple[list[str], list]: + developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) @@ -72,38 +84,15 @@ def update_tool( update_data["spec"] = tool_spec del update_data[tool_type] - tool_cols, tool_vals = cozo_process_mutate_data( - { - **update_data, - "agent_id": agent_id, - "tool_id": tool_id, - } - ) - - # Construct the datalog query for updating the tool information - patch_query = f""" - input[{tool_cols}] <- $input - - ?[{tool_cols}, created_at, updated_at] := - *tools {{ - agent_id: to_uuid($agent_id), - tool_id: to_uuid($tool_id), - created_at - }}, - input[{tool_cols}], - updated_at = now() - - :put tools {{ {tool_cols}, created_at, updated_at }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - patch_query, - ] - return ( - queries, - dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id), + sql_query.format(), + [ + developer_id, + agent_id, + tool_id, + tool_type, + data.name, + data.description, + tool_spec, + ], ) From 0252a8870aafdba70213e7271a18f083b4b948be Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Sat, 21 Dec 2024 21:05:17 +0300 Subject: [PATCH 16/21] WIP --- .../tools/get_tool_args_from_metadata.py | 33 +++++-------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py index 2cdb92cb9..a8a9dba1a 100644 --- a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -2,16 +2,9 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) @@ -51,10 +44,6 @@ def tool_args_for_task( """ queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "tasks", task_id=task_id, parents=[("agents", "agent_id")] - ), get_query, ] @@ -95,25 +84,21 @@ def tool_args_for_session( """ queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), get_query, ] return (queries, {"agent_id": agent_id, "session_id": session_id}) -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class(dict, transform=lambda x: x["values"], one=True) -@cozo_query +@pg_query @beartype def get_tool_args_from_metadata( *, From d209e773c6340d618333fecea4b78309774bac6e Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Mon, 23 Dec 2024 11:32:13 +0300 Subject: [PATCH 17/21] feat: Add tools args from metadata query --- .../tools/get_tool_args_from_metadata.py | 151 +++++++----------- 1 file changed, 59 insertions(+), 92 deletions(-) diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py index a8a9dba1a..57453cd34 100644 --- a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -1,93 +1,62 @@ from typing import Literal from uuid import UUID +import sqlvalidator from beartype import beartype +from ...exceptions import InvalidSQLQuery from ..utils import ( pg_query, wrap_in_class, ) +tools_args_for_task_query = sqlvalidator.parse( + """SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM ( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM tasks + WHERE task_id = $2 AND developer_id = $4 LIMIT 1 +) AS tasks_md""" +) -def tool_args_for_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, - tool_type: Literal["integration", "api_call"] = "integration", - arg_type: Literal["args", "setup"] = "args", -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - task_id = str(task_id) - - get_query = f""" - input[agent_id, task_id] <- [[to_uuid($agent_id), to_uuid($task_id)]] - - ?[values] := - input[agent_id, task_id], - *tasks {{ - task_id, - metadata: task_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - task_{arg_type} = get(task_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, task_{arg_type}), - - :limit 1 - """ - - queries = [ - get_query, - ] - - return (queries, {"agent_id": agent_id, "task_id": task_id}) - - -def tool_args_for_session( - *, - developer_id: UUID, - session_id: UUID, - agent_id: UUID, - arg_type: Literal["args", "setup"] = "args", - tool_type: Literal["integration", "api_call"] = "integration", -) -> tuple[list[str], dict]: - session_id = str(session_id) - - get_query = f""" - input[session_id, agent_id] <- [[to_uuid($session_id), to_uuid($agent_id)]] - - ?[values] := - input[session_id, agent_id], - *sessions {{ - session_id, - metadata: session_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - session_{arg_type} = get(session_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, session_{arg_type}), - - :limit 1 - """ - - queries = [ - get_query, - ] +if not tools_args_for_task_query.is_valid(): + raise InvalidSQLQuery("tools_args_for_task_query") + +tool_args_for_session_query = sqlvalidator.parse( + """SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM sessions + WHERE session_id = $2 AND developer_id = $4 LIMIT 1 +) AS sessions_md""" +) - return (queries, {"agent_id": agent_id, "session_id": session_id}) +if not tool_args_for_session_query.is_valid(): + raise InvalidSQLQuery("tool_args_for_session") # @rewrap_exceptions( @@ -108,25 +77,23 @@ def get_tool_args_from_metadata( task_id: UUID | None = None, tool_type: Literal["integration", "api_call"] = "integration", arg_type: Literal["args", "setup", "headers"] = "args", -) -> tuple[list[str], dict]: - common: dict = dict( - developer_id=developer_id, - agent_id=agent_id, - tool_type=tool_type, - arg_type=arg_type, - ) - +) -> tuple[list[str], list]: match session_id, task_id: case (None, task_id) if task_id is not None: - return tool_args_for_task( - **common, - task_id=task_id, + return ( + tools_args_for_task_query.format(), + [ + agent_id, + task_id, + f"x-{tool_type}-{arg_type}", + developer_id, + ], ) case (session_id, None) if session_id is not None: - return tool_args_for_session( - **common, - session_id=session_id, + return ( + tool_args_for_session_query.format(), + [agent_id, session_id, f"x-{tool_type}-{arg_type}", developer_id], ) case (_, _): From 583bf66c89fc35063f6c9afc0d030a279f2595d0 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Mon, 23 Dec 2024 14:42:53 +0300 Subject: [PATCH 18/21] fix: Remove sql validation, fix tests --- .../agents_api/queries/tools/create_tools.py | 18 +- .../agents_api/queries/tools/delete_tool.py | 13 +- .../agents_api/queries/tools/get_tool.py | 16 +- .../tools/get_tool_args_from_metadata.py | 26 +- .../agents_api/queries/tools/list_tools.py | 23 +- .../agents_api/queries/tools/patch_tool.py | 15 +- .../agents_api/queries/tools/update_tool.py | 15 +- agents-api/tests/fixtures.py | 88 ++--- agents-api/tests/test_tool_queries.py | 344 +++++++++--------- 9 files changed, 281 insertions(+), 277 deletions(-) diff --git a/agents-api/agents_api/queries/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py index d50e98e80..075497541 100644 --- a/agents-api/agents_api/queries/tools/create_tools.py +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -20,8 +20,7 @@ T = TypeVar("T") -sql_query = sqlvalidator.parse( - """INSERT INTO tools +sql_query = """INSERT INTO tools ( developer_id, agent_id, @@ -45,10 +44,10 @@ ) RETURNING * """ -) -if not sql_query.is_valid(): - raise InvalidSQLQuery("create_tools") + +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("create_tools") # @rewrap_exceptions( @@ -61,22 +60,21 @@ @wrap_in_class( Tool, transform=lambda d: { - "id": UUID(d.pop("tool_id")), + "id": d.pop("tool_id"), d["type"]: d.pop("spec"), **d, }, - _kind="inserted", ) @pg_query @increase_counter("create_tools") @beartype -def create_tools( +async def create_tools( *, developer_id: UUID, agent_id: UUID, data: list[CreateToolRequest], ignore_existing: bool = False, # TODO: what to do with this flag? -) -> tuple[list[str], list]: +) -> tuple[str, list] | tuple[str, list, str]: """ Constructs a datalog query for inserting tool records into the 'agent_functions' relation in the CozoDB. @@ -108,7 +106,7 @@ def create_tools( ] return ( - sql_query.format(), + sql_query, tools_data, "fetchmany", ) diff --git a/agents-api/agents_api/queries/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py index 17535e1e4..c67cdaba5 100644 --- a/agents-api/agents_api/queries/tools/delete_tool.py +++ b/agents-api/agents_api/queries/tools/delete_tool.py @@ -16,7 +16,7 @@ T = TypeVar("T") -sql_query = sqlvalidator.parse(""" +sql_query = """ DELETE FROM tools WHERE @@ -24,10 +24,10 @@ agent_id = $2 AND tool_id = $3 RETURNING * -""") +""" -if not sql_query.is_valid(): - raise InvalidSQLQuery("delete_tool") +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("delete_tool") # @rewrap_exceptions( @@ -41,11 +41,10 @@ ResourceDeletedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "deleted_at": utcnow(), "jobs": [], **d}, - _kind="deleted", ) @pg_query @beartype -def delete_tool( +async def delete_tool( *, developer_id: UUID, agent_id: UUID, @@ -56,7 +55,7 @@ def delete_tool( tool_id = str(tool_id) return ( - sql_query.format(), + sql_query, [ developer_id, agent_id, diff --git a/agents-api/agents_api/queries/tools/get_tool.py b/agents-api/agents_api/queries/tools/get_tool.py index af63be0c9..7581714e9 100644 --- a/agents-api/agents_api/queries/tools/get_tool.py +++ b/agents-api/agents_api/queries/tools/get_tool.py @@ -14,17 +14,17 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") -sql_query = sqlvalidator.parse(""" +sql_query = """ SELECT * FROM tools WHERE developer_id = $1 AND agent_id = $2 AND tool_id = $3 LIMIT 1 -""") +""" -if not sql_query.is_valid(): - raise InvalidSQLQuery("get_tool") +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("get_tool") # @rewrap_exceptions( @@ -37,7 +37,7 @@ @wrap_in_class( Tool, transform=lambda d: { - "id": UUID(d.pop("tool_id")), + "id": d.pop("tool_id"), d["type"]: d.pop("spec"), **d, }, @@ -45,18 +45,18 @@ ) @pg_query @beartype -def get_tool( +async def get_tool( *, developer_id: UUID, agent_id: UUID, tool_id: UUID, -) -> tuple[list[str], list]: +) -> tuple[str, list] | tuple[str, list, str]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) return ( - sql_query.format(), + sql_query, [ developer_id, agent_id, diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py index 57453cd34..f4caf5524 100644 --- a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -10,8 +10,7 @@ wrap_in_class, ) -tools_args_for_task_query = sqlvalidator.parse( - """SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM ( +tools_args_for_task_query = """SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM ( SELECT CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' @@ -29,13 +28,12 @@ FROM tasks WHERE task_id = $2 AND developer_id = $4 LIMIT 1 ) AS tasks_md""" -) -if not tools_args_for_task_query.is_valid(): - raise InvalidSQLQuery("tools_args_for_task_query") -tool_args_for_session_query = sqlvalidator.parse( - """SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( +# if not tools_args_for_task_query.is_valid(): +# raise InvalidSQLQuery("tools_args_for_task_query") + +tool_args_for_session_query = """SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( SELECT CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' @@ -53,10 +51,10 @@ FROM sessions WHERE session_id = $2 AND developer_id = $4 LIMIT 1 ) AS sessions_md""" -) -if not tool_args_for_session_query.is_valid(): - raise InvalidSQLQuery("tool_args_for_session") + +# if not tool_args_for_session_query.is_valid(): +# raise InvalidSQLQuery("tool_args_for_session") # @rewrap_exceptions( @@ -69,7 +67,7 @@ @wrap_in_class(dict, transform=lambda x: x["values"], one=True) @pg_query @beartype -def get_tool_args_from_metadata( +async def get_tool_args_from_metadata( *, developer_id: UUID, agent_id: UUID, @@ -77,11 +75,11 @@ def get_tool_args_from_metadata( task_id: UUID | None = None, tool_type: Literal["integration", "api_call"] = "integration", arg_type: Literal["args", "setup", "headers"] = "args", -) -> tuple[list[str], list]: +) -> tuple[str, list] | tuple[str, list, str]: match session_id, task_id: case (None, task_id) if task_id is not None: return ( - tools_args_for_task_query.format(), + tools_args_for_task_query, [ agent_id, task_id, @@ -92,7 +90,7 @@ def get_tool_args_from_metadata( case (session_id, None) if session_id is not None: return ( - tool_args_for_session_query.format(), + tool_args_for_session_query, [agent_id, session_id, f"x-{tool_type}-{arg_type}", developer_id], ) diff --git a/agents-api/agents_api/queries/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py index 3dac84875..01460e16b 100644 --- a/agents-api/agents_api/queries/tools/list_tools.py +++ b/agents-api/agents_api/queries/tools/list_tools.py @@ -14,21 +14,21 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") -sql_query = sqlvalidator.parse(""" +sql_query = """ SELECT * FROM tools WHERE developer_id = $1 AND agent_id = $2 ORDER BY - CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN s.created_at END DESC, - CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN s.created_at END ASC, - CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN s.updated_at END DESC, - CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN s.updated_at END ASC + CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN tools.created_at END DESC NULLS LAST, + CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN tools.created_at END ASC NULLS LAST, + CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN tools.updated_at END DESC NULLS LAST, + CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN tools.updated_at END ASC NULLS LAST LIMIT $3 OFFSET $4; -""") +""" -if not sql_query.is_valid(): - raise InvalidSQLQuery("list_tools") +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("list_tools") # @rewrap_exceptions( @@ -46,12 +46,13 @@ "name": d["name"], "description": d["description"], }, + "id": d.pop("tool_id"), **d, }, ) @pg_query @beartype -def list_tools( +async def list_tools( *, developer_id: UUID, agent_id: UUID, @@ -59,12 +60,12 @@ def list_tools( offset: int = 0, sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", -) -> tuple[list[str], list]: +) -> tuple[str, list] | tuple[str, list, str]: developer_id = str(developer_id) agent_id = str(agent_id) return ( - sql_query.format(), + sql_query, [ developer_id, agent_id, diff --git a/agents-api/agents_api/queries/tools/patch_tool.py b/agents-api/agents_api/queries/tools/patch_tool.py index aa663dec0..a8adf1fa6 100644 --- a/agents-api/agents_api/queries/tools/patch_tool.py +++ b/agents-api/agents_api/queries/tools/patch_tool.py @@ -16,7 +16,7 @@ T = TypeVar("T") -sql_query = sqlvalidator.parse(""" +sql_query = """ WITH updated_tools AS ( UPDATE tools SET @@ -31,10 +31,10 @@ RETURNING * ) SELECT * FROM updated_tools; -""") +""" -if not sql_query.is_valid(): - raise InvalidSQLQuery("patch_tool") +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("patch_tool") # @rewrap_exceptions( @@ -48,14 +48,13 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, - _kind="inserted", ) @pg_query @increase_counter("patch_tool") @beartype -def patch_tool( +async def patch_tool( *, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest -) -> tuple[list[str], list]: +) -> tuple[str, list] | tuple[str, list, str]: """ Execute the datalog query and return the results as a DataFrame Updates the tool information for a given agent and tool ID in the 'cozodb' database. @@ -95,7 +94,7 @@ def patch_tool( del patch_data[tool_type] return ( - sql_query.format(), + sql_query, [ developer_id, agent_id, diff --git a/agents-api/agents_api/queries/tools/update_tool.py b/agents-api/agents_api/queries/tools/update_tool.py index 356e28bbf..bb1d8dc87 100644 --- a/agents-api/agents_api/queries/tools/update_tool.py +++ b/agents-api/agents_api/queries/tools/update_tool.py @@ -18,7 +18,7 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") -sql_query = sqlvalidator.parse(""" +sql_query = """ UPDATE tools SET type = $4, @@ -30,10 +30,10 @@ agent_id = $2 AND tool_id = $3 RETURNING *; -""") +""" -if not sql_query.is_valid(): - raise InvalidSQLQuery("update_tool") +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("update_tool") # @rewrap_exceptions( @@ -47,19 +47,18 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, - _kind="inserted", ) @pg_query @increase_counter("update_tool") @beartype -def update_tool( +async def update_tool( *, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: UpdateToolRequest, **kwargs, -) -> tuple[list[str], list]: +) -> tuple[str, list] | tuple[str, list, str]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) @@ -85,7 +84,7 @@ def update_tool( del update_data[tool_type] return ( - sql_query.format(), + sql_query, [ developer_id, agent_id, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index ea3866ff2..b342cd0b7 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -13,29 +13,30 @@ CreateSessionRequest, CreateTaskRequest, CreateUserRequest, + CreateToolRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode from agents_api.queries.agents.create_agent import create_agent from agents_api.queries.developers.create_developer import create_developer -# from agents_api.queries.agents.delete_agent import delete_agent +from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer from agents_api.queries.docs.create_doc import create_doc -# from agents_api.queries.docs.delete_doc import delete_doc -# from agents_api.queries.execution.create_execution import create_execution -# from agents_api.queries.execution.create_execution_transition import create_execution_transition -# from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup +from agents_api.queries.docs.delete_doc import delete_doc +# from agents_api.queries.executions.create_execution import create_execution +# from agents_api.queries.executions.create_execution_transition import create_execution_transition +# from agents_api.queries.executions.create_temporal_lookup import create_temporal_lookup from agents_api.queries.files.create_file import create_file -# from agents_api.queries.files.delete_file import delete_file +from agents_api.queries.files.delete_file import delete_file from agents_api.queries.sessions.create_session import create_session from agents_api.queries.tasks.create_task import create_task -# from agents_api.queries.task.delete_task import delete_task -# from agents_api.queries.tools.create_tools import create_tools -# from agents_api.queries.tools.delete_tool import delete_tool +from agents_api.queries.tasks.delete_task import delete_task +from agents_api.queries.tools.create_tools import create_tools +from agents_api.queries.tools.delete_tool import delete_tool from agents_api.queries.users.create_user import create_user from agents_api.web import app @@ -347,40 +348,41 @@ async def test_session( # yield transition -# @fixture(scope="global") -# async def test_tool( -# dsn=pg_dsn, -# developer_id=test_developer_id, -# agent=test_agent, -# ): -# function = { -# "description": "A function that prints hello world", -# "parameters": {"type": "object", "properties": {}}, -# } - -# tool = { -# "function": function, -# "name": "hello_world1", -# "type": "function", -# } - -# [tool, *_] = await create_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# connection_pool=pool, -# ) -# yield tool - -# # Cleanup -# try: -# await delete_tool( -# developer_id=developer_id, -# tool_id=tool.id, -# connection_pool=pool, -# ) -# finally: -# await pool.close() +@fixture(scope="global") +async def test_tool( + dsn=pg_dsn, + developer_id=test_developer_id, + agent=test_agent, +): + pool = await create_db_pool(dsn=dsn) + function = { + "description": "A function that prints hello world", + "parameters": {"type": "object", "properties": {}}, + } + + tool = { + "function": function, + "name": "hello_world1", + "type": "function", + } + + [tool, *_] = await create_tools( + developer_id=developer_id, + agent_id=agent.id, + data=[CreateToolRequest(**tool)], + connection_pool=pool, + ) + yield tool + + # Cleanup + try: + await delete_tool( + developer_id=developer_id, + tool_id=tool.id, + connection_pool=pool, + ) + finally: + await pool.close() @fixture(scope="global") diff --git a/agents-api/tests/test_tool_queries.py b/agents-api/tests/test_tool_queries.py index f6f4bac47..43bdf8159 100644 --- a/agents-api/tests/test_tool_queries.py +++ b/agents-api/tests/test_tool_queries.py @@ -1,170 +1,178 @@ # # Tests for tool queries -# from ward import test - -# from agents_api.autogen.openapi_model import ( -# CreateToolRequest, -# PatchToolRequest, -# Tool, -# UpdateToolRequest, -# ) -# from agents_api.queries.tools.create_tools import create_tools -# from agents_api.queries.tools.delete_tool import delete_tool -# from agents_api.queries.tools.get_tool import get_tool -# from agents_api.queries.tools.list_tools import list_tools -# from agents_api.queries.tools.patch_tool import patch_tool -# from agents_api.queries.tools.update_tool import update_tool -# from tests.fixtures import cozo_client, test_agent, test_developer_id, test_tool - - -# @test("query: create tool") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# function = { -# "name": "hello_world", -# "description": "A function that prints hello world", -# "parameters": {"type": "object", "properties": {}}, -# } - -# tool = { -# "function": function, -# "name": "hello_world", -# "type": "function", -# } - -# result = create_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# client=client, -# ) - -# assert result is not None -# assert isinstance(result[0], Tool) - - -# @test("query: delete tool") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# function = { -# "name": "temp_temp", -# "description": "A function that prints hello world", -# "parameters": {"type": "object", "properties": {}}, -# } - -# tool = { -# "function": function, -# "name": "temp_temp", -# "type": "function", -# } - -# [tool, *_] = create_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# client=client, -# ) - -# result = delete_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert result is not None - - -# @test("query: get tool") -# def _( -# client=cozo_client, developer_id=test_developer_id, tool=test_tool, agent=test_agent -# ): -# result = get_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert result is not None - - -# @test("query: list tools") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -# ): -# result = list_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# client=client, -# ) - -# assert result is not None -# assert all(isinstance(tool, Tool) for tool in result) - - -# @test("query: patch tool") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -# ): -# patch_data = PatchToolRequest( -# **{ -# "name": "patched_tool", -# "function": { -# "description": "A patched function that prints hello world", -# }, -# } -# ) - -# result = patch_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# data=patch_data, -# client=client, -# ) - -# assert result is not None - -# tool = get_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert tool.name == "patched_tool" -# assert tool.function.description == "A patched function that prints hello world" -# assert tool.function.parameters - - -# @test("query: update tool") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -# ): -# update_data = UpdateToolRequest( -# name="updated_tool", -# description="An updated description", -# type="function", -# function={ -# "description": "An updated function that prints hello world", -# }, -# ) - -# result = update_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# data=update_data, -# client=client, -# ) - -# assert result is not None - -# tool = get_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert tool.name == "updated_tool" -# assert not tool.function.parameters +from ward import test + +from agents_api.autogen.openapi_model import ( + CreateToolRequest, + PatchToolRequest, + Tool, + UpdateToolRequest, +) +from agents_api.queries.tools.create_tools import create_tools +from agents_api.queries.tools.delete_tool import delete_tool +from agents_api.queries.tools.get_tool import get_tool +from agents_api.queries.tools.list_tools import list_tools +from agents_api.queries.tools.patch_tool import patch_tool +from agents_api.queries.tools.update_tool import update_tool +from tests.fixtures import test_agent, test_developer_id, pg_dsn, test_tool +from agents_api.clients.pg import create_db_pool + + +@test("query: create tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + function = { + "name": "hello_world", + "description": "A function that prints hello world", + "parameters": {"type": "object", "properties": {}}, + } + + tool = { + "function": function, + "name": "hello_world", + "type": "function", + } + + result = await create_tools( + developer_id=developer_id, + agent_id=agent.id, + data=[CreateToolRequest(**tool)], + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result[0], Tool) + + +@test("query: delete tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + function = { + "name": "temp_temp", + "description": "A function that prints hello world", + "parameters": {"type": "object", "properties": {}}, + } + + tool = { + "function": function, + "name": "temp_temp", + "type": "function", + } + + [tool, *_] = await create_tools( + developer_id=developer_id, + agent_id=agent.id, + data=[CreateToolRequest(**tool)], + connection_pool=pool, + ) + + result = delete_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert result is not None + + +@test("query: get tool") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, tool=test_tool, agent=test_agent +): + pool = await create_db_pool(dsn=dsn) + result = get_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert result is not None + + +@test("query: list tools") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool +): + pool = await create_db_pool(dsn=dsn) + result = await list_tools( + developer_id=developer_id, + agent_id=agent.id, + connection_pool=pool, + ) + + assert result is not None + assert all(isinstance(tool, Tool) for tool in result) + + +@test("query: patch tool") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool +): + pool = await create_db_pool(dsn=dsn) + patch_data = PatchToolRequest( + **{ + "name": "patched_tool", + "function": { + "description": "A patched function that prints hello world", + "parameters": {"param1": "value1"}, + }, + } + ) + + result = await patch_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + data=patch_data, + connection_pool=pool, + ) + + assert result is not None + + tool = await get_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert tool.name == "patched_tool" + assert tool.function.description == "A patched function that prints hello world" + assert tool.function.parameters + + +@test("query: update tool") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool +): + pool = await create_db_pool(dsn=dsn) + update_data = UpdateToolRequest( + name="updated_tool", + description="An updated description", + type="function", + function={ + "description": "An updated function that prints hello world", + }, + ) + + result = await update_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + data=update_data, + connection_pool=pool, + ) + + assert result is not None + + tool = await get_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert tool.name == "updated_tool" + assert not tool.function.parameters From 128cc2fa6031cddb1aa63e4972f95ff66d54ca07 Mon Sep 17 00:00:00 2001 From: whiterabbit1983 Date: Mon, 23 Dec 2024 11:44:18 +0000 Subject: [PATCH 19/21] refactor: Lint agents-api (CI) --- agents-api/tests/fixtures.py | 9 +++------ agents-api/tests/test_tool_queries.py | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index b342cd0b7..a98fef531 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -12,28 +12,25 @@ CreateFileRequest, CreateSessionRequest, CreateTaskRequest, - CreateUserRequest, CreateToolRequest, + CreateUserRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode from agents_api.queries.agents.create_agent import create_agent -from agents_api.queries.developers.create_developer import create_developer - from agents_api.queries.agents.delete_agent import delete_agent +from agents_api.queries.developers.create_developer import create_developer from agents_api.queries.developers.get_developer import get_developer from agents_api.queries.docs.create_doc import create_doc - from agents_api.queries.docs.delete_doc import delete_doc + # from agents_api.queries.executions.create_execution import create_execution # from agents_api.queries.executions.create_execution_transition import create_execution_transition # from agents_api.queries.executions.create_temporal_lookup import create_temporal_lookup from agents_api.queries.files.create_file import create_file - from agents_api.queries.files.delete_file import delete_file from agents_api.queries.sessions.create_session import create_session from agents_api.queries.tasks.create_task import create_task - from agents_api.queries.tasks.delete_task import delete_task from agents_api.queries.tools.create_tools import create_tools from agents_api.queries.tools.delete_tool import delete_tool diff --git a/agents-api/tests/test_tool_queries.py b/agents-api/tests/test_tool_queries.py index 43bdf8159..12698e1be 100644 --- a/agents-api/tests/test_tool_queries.py +++ b/agents-api/tests/test_tool_queries.py @@ -8,14 +8,14 @@ Tool, UpdateToolRequest, ) +from agents_api.clients.pg import create_db_pool from agents_api.queries.tools.create_tools import create_tools from agents_api.queries.tools.delete_tool import delete_tool from agents_api.queries.tools.get_tool import get_tool from agents_api.queries.tools.list_tools import list_tools from agents_api.queries.tools.patch_tool import patch_tool from agents_api.queries.tools.update_tool import update_tool -from tests.fixtures import test_agent, test_developer_id, pg_dsn, test_tool -from agents_api.clients.pg import create_db_pool +from tests.fixtures import pg_dsn, test_agent, test_developer_id, test_tool @test("query: create tool") From 3d6d02344b549c64dc236269f440d7b8f2a4ef7a Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Mon, 23 Dec 2024 14:48:55 +0300 Subject: [PATCH 20/21] fix: Fix awaitable and type hint --- agents-api/agents_api/queries/tools/delete_tool.py | 2 +- agents-api/tests/test_tool_queries.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/agents-api/agents_api/queries/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py index c67cdaba5..91a57bd2f 100644 --- a/agents-api/agents_api/queries/tools/delete_tool.py +++ b/agents-api/agents_api/queries/tools/delete_tool.py @@ -49,7 +49,7 @@ async def delete_tool( developer_id: UUID, agent_id: UUID, tool_id: UUID, -) -> tuple[list[str], list]: +) -> tuple[str, list] | tuple[str, list, str]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) diff --git a/agents-api/tests/test_tool_queries.py b/agents-api/tests/test_tool_queries.py index 12698e1be..5056f03ca 100644 --- a/agents-api/tests/test_tool_queries.py +++ b/agents-api/tests/test_tool_queries.py @@ -66,7 +66,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): connection_pool=pool, ) - result = delete_tool( + result = await delete_tool( developer_id=developer_id, agent_id=agent.id, tool_id=tool.id, From 1c97bc3a0ae41af6fed18d6571397ba43fe1e795 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Mon, 23 Dec 2024 15:00:21 +0300 Subject: [PATCH 21/21] chore: Update type annotations --- agents-api/agents_api/queries/tools/create_tools.py | 2 +- agents-api/agents_api/queries/tools/delete_tool.py | 2 +- agents-api/agents_api/queries/tools/get_tool.py | 2 +- .../agents_api/queries/tools/get_tool_args_from_metadata.py | 2 +- agents-api/agents_api/queries/tools/list_tools.py | 2 +- agents-api/agents_api/queries/tools/patch_tool.py | 2 +- agents-api/agents_api/queries/tools/update_tool.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/agents-api/agents_api/queries/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py index 075497541..70b0525a8 100644 --- a/agents-api/agents_api/queries/tools/create_tools.py +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -74,7 +74,7 @@ async def create_tools( agent_id: UUID, data: list[CreateToolRequest], ignore_existing: bool = False, # TODO: what to do with this flag? -) -> tuple[str, list] | tuple[str, list, str]: +) -> tuple[str, list, str]: """ Constructs a datalog query for inserting tool records into the 'agent_functions' relation in the CozoDB. diff --git a/agents-api/agents_api/queries/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py index 91a57bd2f..cd666ee42 100644 --- a/agents-api/agents_api/queries/tools/delete_tool.py +++ b/agents-api/agents_api/queries/tools/delete_tool.py @@ -49,7 +49,7 @@ async def delete_tool( developer_id: UUID, agent_id: UUID, tool_id: UUID, -) -> tuple[str, list] | tuple[str, list, str]: +) -> tuple[str, list]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) diff --git a/agents-api/agents_api/queries/tools/get_tool.py b/agents-api/agents_api/queries/tools/get_tool.py index 7581714e9..29a7ae9b6 100644 --- a/agents-api/agents_api/queries/tools/get_tool.py +++ b/agents-api/agents_api/queries/tools/get_tool.py @@ -50,7 +50,7 @@ async def get_tool( developer_id: UUID, agent_id: UUID, tool_id: UUID, -) -> tuple[str, list] | tuple[str, list, str]: +) -> tuple[str, list]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py index f4caf5524..8d53a4e1b 100644 --- a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -75,7 +75,7 @@ async def get_tool_args_from_metadata( task_id: UUID | None = None, tool_type: Literal["integration", "api_call"] = "integration", arg_type: Literal["args", "setup", "headers"] = "args", -) -> tuple[str, list] | tuple[str, list, str]: +) -> tuple[str, list]: match session_id, task_id: case (None, task_id) if task_id is not None: return ( diff --git a/agents-api/agents_api/queries/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py index 01460e16b..cdc82d9bd 100644 --- a/agents-api/agents_api/queries/tools/list_tools.py +++ b/agents-api/agents_api/queries/tools/list_tools.py @@ -60,7 +60,7 @@ async def list_tools( offset: int = 0, sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", -) -> tuple[str, list] | tuple[str, list, str]: +) -> tuple[str, list]: developer_id = str(developer_id) agent_id = str(agent_id) diff --git a/agents-api/agents_api/queries/tools/patch_tool.py b/agents-api/agents_api/queries/tools/patch_tool.py index a8adf1fa6..e0a20dc1d 100644 --- a/agents-api/agents_api/queries/tools/patch_tool.py +++ b/agents-api/agents_api/queries/tools/patch_tool.py @@ -54,7 +54,7 @@ @beartype async def patch_tool( *, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest -) -> tuple[str, list] | tuple[str, list, str]: +) -> tuple[str, list]: """ Execute the datalog query and return the results as a DataFrame Updates the tool information for a given agent and tool ID in the 'cozodb' database. diff --git a/agents-api/agents_api/queries/tools/update_tool.py b/agents-api/agents_api/queries/tools/update_tool.py index bb1d8dc87..2b8beb155 100644 --- a/agents-api/agents_api/queries/tools/update_tool.py +++ b/agents-api/agents_api/queries/tools/update_tool.py @@ -58,7 +58,7 @@ async def update_tool( tool_id: UUID, data: UpdateToolRequest, **kwargs, -) -> tuple[str, list] | tuple[str, list, str]: +) -> tuple[str, list]: developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id)