diff --git a/agents-api/agents_api/exceptions.py b/agents-api/agents_api/exceptions.py index 615958a87..f6fcc4741 100644 --- a/agents-api/agents_api/exceptions.py +++ b/agents-api/agents_api/exceptions.py @@ -49,3 +49,12 @@ class FailedEncodingSentinel: """Sentinel object returned when failed to encode payload.""" payload_data: bytes + + +class QueriesBaseException(AgentsBaseException): + pass + + +class InvalidSQLQuery(QueriesBaseException): + def __init__(self, query_name: str): + super().__init__(f"invalid query: {query_name}") diff --git a/agents-api/agents_api/models/chat/get_cached_response.py b/agents-api/agents_api/models/chat/get_cached_response.py deleted file mode 100644 index 368c88567..000000000 --- a/agents-api/agents_api/models/chat/get_cached_response.py +++ /dev/null @@ -1,15 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def get_cached_response(key: str) -> tuple[str, dict]: - query = """ - input[key] <- [[$key]] - ?[key, value] := input[key], *session_cache{key, value} - :limit 1 - """ - - return (query, {"key": key}) diff --git a/agents-api/agents_api/models/chat/prepare_chat_context.py b/agents-api/agents_api/models/chat/prepare_chat_context.py deleted file mode 100644 index f77686d7a..000000000 --- a/agents-api/agents_api/models/chat/prepare_chat_context.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...common.protocol.sessions import ChatContext, make_session -from ..session.prepare_session_data import prepare_session_data -from ..utils import ( - cozo_query, - fix_uuid_if_present, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ChatContext, - one=True, - transform=lambda d: { - **d, - "session": make_session( - agents=[a["id"] for a in d["agents"]], - users=[u["id"] for u in d["users"]], - **d["session"], - ), - "toolsets": [ - { - **ts, - "tools": [ - { - tool["type"]: tool.pop("spec"), - **tool, - } - for tool in map(fix_uuid_if_present, ts["tools"]) - ], - } - for ts in d["toolsets"] - ], - }, -) -@cozo_query -@beartype -def prepare_chat_context( - *, - developer_id: UUID, - session_id: UUID, -) -> tuple[list[str], dict]: - """ - Executes a complex query to retrieve memory context based on session ID. - """ - - [*_, session_data_query], sd_vars = prepare_session_data.__wrapped__( - developer_id=developer_id, session_id=session_id - ) - - session_data_fields = ("session", "agents", "users") - - session_data_query += """ - :create _session_data_json { - agents: [Json], - users: [Json], - session: Json, - } - """ - - toolsets_query = """ - input[session_id] <- [[to_uuid($session_id)]] - - tools_by_agent[agent_id, collect(tool)] := - input[session_id], - *session_lookup{ - session_id, - participant_id: agent_id, - participant_type: "agent", - }, - - *tools { agent_id, tool_id, name, type, spec, description, updated_at, created_at }, - tool = { - "id": tool_id, - "name": name, - "type": type, - "spec": spec, - "description": description, - "updated_at": updated_at, - "created_at": created_at, - } - - agent_toolsets[collect(toolset)] := - tools_by_agent[agent_id, tools], - toolset = { - "agent_id": agent_id, - "tools": tools, - } - - ?[toolsets] := - agent_toolsets[toolsets] - - :create _toolsets_json { - toolsets: [Json], - } - """ - - combine_query = f""" - ?[{', '.join(session_data_fields)}, toolsets] := - *_session_data_json {{ {', '.join(session_data_fields)} }}, - *_toolsets_json {{ toolsets }} - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - session_data_query, - toolsets_query, - combine_query, - ] - - return ( - queries, - { - "session_id": str(session_id), - **sd_vars, - }, - ) diff --git a/agents-api/agents_api/models/chat/set_cached_response.py b/agents-api/agents_api/models/chat/set_cached_response.py deleted file mode 100644 index 8625f3f1b..000000000 --- a/agents-api/agents_api/models/chat/set_cached_response.py +++ /dev/null @@ -1,19 +0,0 @@ -from beartype import beartype - -from ..utils import cozo_query - - -@cozo_query -@beartype -def set_cached_response(key: str, value: dict) -> tuple[str, dict]: - query = """ - ?[key, value] <- [[$key, $value]] - - :insert session_cache { - key => value - } - - :returning - """ - - return (query, {"key": key, "value": value}) diff --git a/agents-api/agents_api/models/tools/create_tools.py b/agents-api/agents_api/models/tools/create_tools.py deleted file mode 100644 index 578a1268d..000000000 --- a/agents-api/agents_api/models/tools/create_tools.py +++ /dev/null @@ -1,136 +0,0 @@ -"""This module contains functions for creating tools in the CozoDB database.""" - -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError -from uuid_extensions import uuid7 - -from ...autogen.openapi_model import CreateToolRequest, Tool -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - AssertionError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Tool, - transform=lambda d: { - "id": UUID(d.pop("tool_id")), - d["type"]: d.pop("spec"), - **d, - }, - _kind="inserted", -) -@cozo_query -@increase_counter("create_tools") -@beartype -def create_tools( - *, - developer_id: UUID, - agent_id: UUID, - data: list[CreateToolRequest], - ignore_existing: bool = False, -) -> tuple[list[str], dict]: - """ - Constructs a datalog query for inserting tool records into the 'agent_functions' relation in the CozoDB. - - Parameters: - agent_id (UUID): The unique identifier for the agent. - data (list[CreateToolRequest]): A list of function definitions to be inserted. - - Returns: - list[Tool] - """ - - assert all( - getattr(tool, tool.type) is not None - for tool in data - if hasattr(tool, tool.type) - ), "Tool spec must be passed" - - tools_data = [ - [ - str(agent_id), - str(uuid7()), - tool.type, - tool.name, - getattr(tool, tool.type) and getattr(tool, tool.type).model_dump(), - tool.description if hasattr(tool, "description") else None, - ] - for tool in data - ] - - ensure_tool_name_unique_query = """ - input[agent_id, tool_id, type, name, spec, description] <- $records - ?[tool_id] := - input[agent_id, _, type, name, _, _], - *tools{ - agent_id: to_uuid(agent_id), - tool_id, - type, - name, - spec, - description, - } - - :limit 1 - :assert none - """ - - # Datalog query for inserting new tool records into the 'tools' relation - create_query = """ - input[agent_id, tool_id, type, name, spec, description] <- $records - - # Do not add duplicate - ?[agent_id, tool_id, type, name, spec, description] := - input[agent_id, tool_id, type, name, spec, description], - not *tools{ - agent_id: to_uuid(agent_id), - type, - name, - } - - :insert tools { - agent_id, - tool_id, - type, - name, - spec, - description, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - create_query, - ] - - if not ignore_existing: - queries.insert( - -1, - ensure_tool_name_unique_query, - ) - - return (queries, {"records": tools_data}) diff --git a/agents-api/agents_api/models/tools/delete_tool.py b/agents-api/agents_api/models/tools/delete_tool.py deleted file mode 100644 index c79cdfd29..000000000 --- a/agents-api/agents_api/models/tools/delete_tool.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: {"id": d["tool_id"], "deleted_at": utcnow(), "jobs": [], **d}, - _kind="deleted", -) -@cozo_query -@beartype -def delete_tool( - *, - developer_id: UUID, - agent_id: UUID, - tool_id: UUID, -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - tool_id = str(tool_id) - - delete_query = """ - # Delete function - ?[tool_id, agent_id] <- [[ - to_uuid($tool_id), - to_uuid($agent_id), - ]] - - :delete tools { - tool_id, - agent_id, - } - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - delete_query, - ] - - return (queries, {"tool_id": tool_id, "agent_id": agent_id}) diff --git a/agents-api/agents_api/models/tools/get_tool.py b/agents-api/agents_api/models/tools/get_tool.py deleted file mode 100644 index 465fd2efe..000000000 --- a/agents-api/agents_api/models/tools/get_tool.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Tool -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Tool, - transform=lambda d: { - "id": UUID(d.pop("tool_id")), - d["type"]: d.pop("spec"), - **d, - }, - one=True, -) -@cozo_query -@beartype -def get_tool( - *, - developer_id: UUID, - agent_id: UUID, - tool_id: UUID, -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - tool_id = str(tool_id) - - get_query = """ - input[agent_id, tool_id] <- [[to_uuid($agent_id), to_uuid($tool_id)]] - - ?[ - agent_id, - tool_id, - type, - name, - spec, - updated_at, - created_at, - ] := input[agent_id, tool_id], - *tools { - agent_id, - tool_id, - name, - type, - spec, - updated_at, - created_at, - } - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - get_query, - ] - - return (queries, {"agent_id": agent_id, "tool_id": tool_id}) diff --git a/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py deleted file mode 100644 index 2cdb92cb9..000000000 --- a/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Literal -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - - -def tool_args_for_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, - tool_type: Literal["integration", "api_call"] = "integration", - arg_type: Literal["args", "setup"] = "args", -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - task_id = str(task_id) - - get_query = f""" - input[agent_id, task_id] <- [[to_uuid($agent_id), to_uuid($task_id)]] - - ?[values] := - input[agent_id, task_id], - *tasks {{ - task_id, - metadata: task_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - task_{arg_type} = get(task_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, task_{arg_type}), - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "tasks", task_id=task_id, parents=[("agents", "agent_id")] - ), - get_query, - ] - - return (queries, {"agent_id": agent_id, "task_id": task_id}) - - -def tool_args_for_session( - *, - developer_id: UUID, - session_id: UUID, - agent_id: UUID, - arg_type: Literal["args", "setup"] = "args", - tool_type: Literal["integration", "api_call"] = "integration", -) -> tuple[list[str], dict]: - session_id = str(session_id) - - get_query = f""" - input[session_id, agent_id] <- [[to_uuid($session_id), to_uuid($agent_id)]] - - ?[values] := - input[session_id, agent_id], - *sessions {{ - session_id, - metadata: session_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - session_{arg_type} = get(session_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, session_{arg_type}), - - :limit 1 - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query( - developer_id, "sessions", session_id=session_id - ), - get_query, - ] - - return (queries, {"agent_id": agent_id, "session_id": session_id}) - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(dict, transform=lambda x: x["values"], one=True) -@cozo_query -@beartype -def get_tool_args_from_metadata( - *, - developer_id: UUID, - agent_id: UUID, - session_id: UUID | None = None, - task_id: UUID | None = None, - tool_type: Literal["integration", "api_call"] = "integration", - arg_type: Literal["args", "setup", "headers"] = "args", -) -> tuple[list[str], dict]: - common: dict = dict( - developer_id=developer_id, - agent_id=agent_id, - tool_type=tool_type, - arg_type=arg_type, - ) - - match session_id, task_id: - case (None, task_id) if task_id is not None: - return tool_args_for_task( - **common, - task_id=task_id, - ) - - case (session_id, None) if session_id is not None: - return tool_args_for_session( - **common, - session_id=session_id, - ) - - case (_, _): - raise ValueError("Either session_id or task_id must be provided") diff --git a/agents-api/agents_api/models/tools/list_tools.py b/agents-api/agents_api/models/tools/list_tools.py deleted file mode 100644 index 727bf8028..000000000 --- a/agents-api/agents_api/models/tools/list_tools.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Any, Literal, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import Tool -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - Tool, - transform=lambda d: { - d["type"]: { - **d.pop("spec"), - "name": d["name"], - "description": d["description"], - }, - **d, - }, -) -@cozo_query -@beartype -def list_tools( - *, - developer_id: UUID, - agent_id: UUID, - limit: int = 100, - offset: int = 0, - sort_by: Literal["created_at", "updated_at"] = "created_at", - direction: Literal["asc", "desc"] = "desc", -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - - sort = f"{'-' if direction == 'desc' else ''}{sort_by}" - - list_query = f""" - input[agent_id] <- [[to_uuid($agent_id)]] - - ?[ - agent_id, - id, - name, - type, - spec, - description, - updated_at, - created_at, - ] := input[agent_id], - *tools {{ - agent_id, - tool_id: id, - name, - type, - spec, - description, - updated_at, - created_at, - }} - - :limit $limit - :offset $offset - :sort {sort} - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - list_query, - ] - - return ( - queries, - {"agent_id": agent_id, "limit": limit, "offset": offset}, - ) diff --git a/agents-api/agents_api/models/tools/update_tool.py b/agents-api/agents_api/models/tools/update_tool.py deleted file mode 100644 index ef700a5f6..000000000 --- a/agents-api/agents_api/models/tools/update_tool.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Any, TypeVar -from uuid import UUID - -from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError - -from ...autogen.openapi_model import ( - ResourceUpdatedResponse, - UpdateToolRequest, -) -from ...common.utils.cozo import cozo_process_mutate_data -from ...metrics.counters import increase_counter -from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, - wrap_in_class, -) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - - -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class( - ResourceUpdatedResponse, - one=True, - transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, - _kind="inserted", -) -@cozo_query -@increase_counter("update_tool") -@beartype -def update_tool( - *, - developer_id: UUID, - agent_id: UUID, - tool_id: UUID, - data: UpdateToolRequest, - **kwargs, -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - tool_id = str(tool_id) - - # Extract the tool data from the payload - update_data = data.model_dump(exclude_none=True) - - # Assert that only one of the tool type fields is present - tool_specs = [ - (tool_type, update_data.get(tool_type)) - for tool_type in ["function", "integration", "system", "api_call"] - if update_data.get(tool_type) is not None - ] - - assert len(tool_specs) <= 1, "Invalid tool update" - tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None) - - if tool_type is not None: - update_data["type"] = update_data.get("type", tool_type) - assert update_data["type"] == tool_type, "Invalid tool update" - - update_data["spec"] = tool_spec - del update_data[tool_type] - - tool_cols, tool_vals = cozo_process_mutate_data( - { - **update_data, - "agent_id": agent_id, - "tool_id": tool_id, - } - ) - - # Construct the datalog query for updating the tool information - patch_query = f""" - input[{tool_cols}] <- $input - - ?[{tool_cols}, created_at, updated_at] := - *tools {{ - agent_id: to_uuid($agent_id), - tool_id: to_uuid($tool_id), - created_at - }}, - input[{tool_cols}], - updated_at = now() - - :put tools {{ {tool_cols}, created_at, updated_at }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - patch_query, - ] - - return ( - queries, - dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id), - ) diff --git a/agents-api/agents_api/models/chat/__init__.py b/agents-api/agents_api/queries/chat/__init__.py similarity index 92% rename from agents-api/agents_api/models/chat/__init__.py rename to agents-api/agents_api/queries/chat/__init__.py index 428b72572..2c05b4f8b 100644 --- a/agents-api/agents_api/models/chat/__init__.py +++ b/agents-api/agents_api/queries/chat/__init__.py @@ -17,6 +17,4 @@ # ruff: noqa: F401, F403, F405 from .gather_messages import gather_messages -from .get_cached_response import get_cached_response from .prepare_chat_context import prepare_chat_context -from .set_cached_response import set_cached_response diff --git a/agents-api/agents_api/models/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py similarity index 95% rename from agents-api/agents_api/models/chat/gather_messages.py rename to agents-api/agents_api/queries/chat/gather_messages.py index 28dc6607f..cbf3bf209 100644 --- a/agents-api/agents_api/models/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -3,7 +3,6 @@ from beartype import beartype from fastapi import HTTPException -from pycozo.client import QueryException from pydantic import ValidationError from ...autogen.openapi_model import ChatInput, DocReference, History @@ -13,8 +12,8 @@ from ..docs.search_docs_by_embedding import search_docs_by_embedding from ..docs.search_docs_by_text import search_docs_by_text from ..docs.search_docs_hybrid import search_docs_hybrid -from ..entry.get_history import get_history -from ..session.get_session import get_session +from ..entries.get_history import get_history +from ..sessions.get_session import get_session from ..utils import ( partialclass, rewrap_exceptions, @@ -25,7 +24,6 @@ @rewrap_exceptions( { - QueryException: partialclass(HTTPException, status_code=400), ValidationError: partialclass(HTTPException, status_code=400), TypeError: partialclass(HTTPException, status_code=400), } diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py new file mode 100644 index 000000000..1d9bd52fb --- /dev/null +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -0,0 +1,169 @@ +from typing import Any, TypeVar +from uuid import UUID + +import sqlvalidator +from beartype import beartype + +from ...common.protocol.sessions import ChatContext, make_session +from ...exceptions import InvalidSQLQuery +from ..utils import ( + pg_query, + wrap_in_class, +) + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + + +sql_query = sqlvalidator.parse( + """SELECT * FROM +( + SELECT jsonb_agg(u) AS users FROM ( + SELECT + session_lookup.participant_id, + users.user_id AS id, + users.developer_id, + users.name, + users.about, + users.created_at, + users.updated_at, + users.metadata + FROM session_lookup + INNER JOIN users ON session_lookup.participant_id = users.user_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'user' + ) u +) AS users, +( + SELECT jsonb_agg(a) AS agents FROM ( + SELECT + session_lookup.participant_id, + agents.agent_id AS id, + agents.developer_id, + agents.canonical_name, + agents.name, + agents.about, + agents.instructions, + agents.model, + agents.created_at, + agents.updated_at, + agents.metadata, + agents.default_settings + FROM session_lookup + INNER JOIN agents ON session_lookup.participant_id = agents.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) a +) AS agents, +( + SELECT to_jsonb(s) AS session FROM ( + SELECT + sessions.session_id AS id, + sessions.developer_id, + sessions.situation, + sessions.system_template, + sessions.created_at, + sessions.metadata, + sessions.render_templates, + sessions.token_budget, + sessions.context_overflow, + sessions.forward_tool_calls, + sessions.recall_options + FROM sessions + WHERE + developer_id = $1 AND + session_id = $2 + LIMIT 1 + ) s +) AS session, +( + SELECT jsonb_agg(r) AS toolsets FROM ( + SELECT + session_lookup.participant_id, + tools.tool_id as id, + tools.developer_id, + tools.agent_id, + tools.task_id, + tools.task_version, + tools.type, + tools.name, + tools.description, + tools.spec, + tools.updated_at, + tools.created_at + FROM session_lookup + INNER JOIN tools ON session_lookup.participant_id = tools.agent_id + WHERE + session_lookup.developer_id = $1 AND + session_id = $2 AND + session_lookup.participant_type = 'agent' + ) r +) AS toolsets""" +) +if not sql_query.is_valid(): + raise InvalidSQLQuery("prepare_chat_context") + + +def _transform(d): + toolsets = {} + for tool in d["toolsets"]: + agent_id = tool["agent_id"] + if agent_id in toolsets: + toolsets[agent_id].append(tool) + else: + toolsets[agent_id] = [tool] + + return { + **d, + "session": make_session( + agents=[a["id"] for a in d["agents"]], + users=[u["id"] for u in d["users"]], + **d["session"], + ), + "toolsets": [ + { + "agent_id": agent_id, + "tools": [ + { + tool["type"]: tool.pop("spec"), + **tool, + } + for tool in tools + ], + } + for agent_id, tools in toolsets.items() + ], + } + + +# TODO: implement this part +# @rewrap_exceptions( +# { +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + ChatContext, + one=True, + transform=_transform, +) +@pg_query +@beartype +async def prepare_chat_context( + *, + developer_id: UUID, + session_id: UUID, +) -> tuple[list[str], list]: + """ + Executes a complex query to retrieve memory context based on session ID. + """ + + return ( + [sql_query.format()], + [developer_id, session_id], + ) diff --git a/agents-api/agents_api/models/tools/__init__.py b/agents-api/agents_api/queries/tools/__init__.py similarity index 100% rename from agents-api/agents_api/models/tools/__init__.py rename to agents-api/agents_api/queries/tools/__init__.py diff --git a/agents-api/agents_api/queries/tools/create_tools.py b/agents-api/agents_api/queries/tools/create_tools.py new file mode 100644 index 000000000..70b0525a8 --- /dev/null +++ b/agents-api/agents_api/queries/tools/create_tools.py @@ -0,0 +1,112 @@ +"""This module contains functions for creating tools in the CozoDB database.""" + +from typing import Any, TypeVar +from uuid import UUID + +import sqlvalidator +from beartype import beartype +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateToolRequest, Tool +from ...exceptions import InvalidSQLQuery +from ...metrics.counters import increase_counter +from ..utils import ( + pg_query, + # rewrap_exceptions, + wrap_in_class, +) + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + + +sql_query = """INSERT INTO tools +( + developer_id, + agent_id, + tool_id, + type, + name, + spec, + description +) +SELECT + $1, + $2, + $3, + $4, + $5, + $6, + $7 +WHERE NOT EXISTS ( + SELECT null FROM tools + WHERE (agent_id, name) = ($2, $5) +) +RETURNING * +""" + + +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("create_tools") + + +# @rewrap_exceptions( +# { +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# AssertionError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + Tool, + transform=lambda d: { + "id": d.pop("tool_id"), + d["type"]: d.pop("spec"), + **d, + }, +) +@pg_query +@increase_counter("create_tools") +@beartype +async def create_tools( + *, + developer_id: UUID, + agent_id: UUID, + data: list[CreateToolRequest], + ignore_existing: bool = False, # TODO: what to do with this flag? +) -> tuple[str, list, str]: + """ + Constructs a datalog query for inserting tool records into the 'agent_functions' relation in the CozoDB. + + Parameters: + agent_id (UUID): The unique identifier for the agent. + data (list[CreateToolRequest]): A list of function definitions to be inserted. + + Returns: + list[Tool] + """ + + assert all( + getattr(tool, tool.type) is not None + for tool in data + if hasattr(tool, tool.type) + ), "Tool spec must be passed" + + tools_data = [ + [ + developer_id, + str(agent_id), + str(uuid7()), + tool.type, + tool.name, + getattr(tool, tool.type) and getattr(tool, tool.type).model_dump(), + tool.description if hasattr(tool, "description") else None, + ] + for tool in data + ] + + return ( + sql_query, + tools_data, + "fetchmany", + ) diff --git a/agents-api/agents_api/queries/tools/delete_tool.py b/agents-api/agents_api/queries/tools/delete_tool.py new file mode 100644 index 000000000..cd666ee42 --- /dev/null +++ b/agents-api/agents_api/queries/tools/delete_tool.py @@ -0,0 +1,64 @@ +from typing import Any, TypeVar +from uuid import UUID + +import sqlvalidator +from beartype import beartype + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...exceptions import InvalidSQLQuery +from ..utils import ( + pg_query, + wrap_in_class, +) + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + + +sql_query = """ +DELETE FROM + tools +WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 +RETURNING * +""" + +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("delete_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: {"id": d["tool_id"], "deleted_at": utcnow(), "jobs": [], **d}, +) +@pg_query +@beartype +async def delete_tool( + *, + developer_id: UUID, + agent_id: UUID, + tool_id: UUID, +) -> tuple[str, list]: + developer_id = str(developer_id) + agent_id = str(agent_id) + tool_id = str(tool_id) + + return ( + sql_query, + [ + developer_id, + agent_id, + tool_id, + ], + ) diff --git a/agents-api/agents_api/queries/tools/get_tool.py b/agents-api/agents_api/queries/tools/get_tool.py new file mode 100644 index 000000000..29a7ae9b6 --- /dev/null +++ b/agents-api/agents_api/queries/tools/get_tool.py @@ -0,0 +1,65 @@ +from typing import Any, TypeVar +from uuid import UUID + +import sqlvalidator +from beartype import beartype + +from ...autogen.openapi_model import Tool +from ...exceptions import InvalidSQLQuery +from ..utils import ( + pg_query, + wrap_in_class, +) + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + +sql_query = """ +SELECT * FROM tools +WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 +LIMIT 1 +""" + +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("get_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + Tool, + transform=lambda d: { + "id": d.pop("tool_id"), + d["type"]: d.pop("spec"), + **d, + }, + one=True, +) +@pg_query +@beartype +async def get_tool( + *, + developer_id: UUID, + agent_id: UUID, + tool_id: UUID, +) -> tuple[str, list]: + developer_id = str(developer_id) + agent_id = str(agent_id) + tool_id = str(tool_id) + + return ( + sql_query, + [ + developer_id, + agent_id, + tool_id, + ], + ) diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py new file mode 100644 index 000000000..8d53a4e1b --- /dev/null +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -0,0 +1,98 @@ +from typing import Literal +from uuid import UUID + +import sqlvalidator +from beartype import beartype + +from ...exceptions import InvalidSQLQuery +from ..utils import ( + pg_query, + wrap_in_class, +) + +tools_args_for_task_query = """SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM ( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM tasks + WHERE task_id = $2 AND developer_id = $4 LIMIT 1 +) AS tasks_md""" + + +# if not tools_args_for_task_query.is_valid(): +# raise InvalidSQLQuery("tools_args_for_task_query") + +tool_args_for_session_query = """SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM sessions + WHERE session_id = $2 AND developer_id = $4 LIMIT 1 +) AS sessions_md""" + + +# if not tool_args_for_session_query.is_valid(): +# raise InvalidSQLQuery("tool_args_for_session") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class(dict, transform=lambda x: x["values"], one=True) +@pg_query +@beartype +async def get_tool_args_from_metadata( + *, + developer_id: UUID, + agent_id: UUID, + session_id: UUID | None = None, + task_id: UUID | None = None, + tool_type: Literal["integration", "api_call"] = "integration", + arg_type: Literal["args", "setup", "headers"] = "args", +) -> tuple[str, list]: + match session_id, task_id: + case (None, task_id) if task_id is not None: + return ( + tools_args_for_task_query, + [ + agent_id, + task_id, + f"x-{tool_type}-{arg_type}", + developer_id, + ], + ) + + case (session_id, None) if session_id is not None: + return ( + tool_args_for_session_query, + [agent_id, session_id, f"x-{tool_type}-{arg_type}", developer_id], + ) + + case (_, _): + raise ValueError("Either session_id or task_id must be provided") diff --git a/agents-api/agents_api/queries/tools/list_tools.py b/agents-api/agents_api/queries/tools/list_tools.py new file mode 100644 index 000000000..cdc82d9bd --- /dev/null +++ b/agents-api/agents_api/queries/tools/list_tools.py @@ -0,0 +1,77 @@ +from typing import Any, Literal, TypeVar +from uuid import UUID + +import sqlvalidator +from beartype import beartype + +from ...autogen.openapi_model import Tool +from ...exceptions import InvalidSQLQuery +from ..utils import ( + pg_query, + wrap_in_class, +) + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + +sql_query = """ +SELECT * FROM tools +WHERE + developer_id = $1 AND + agent_id = $2 +ORDER BY + CASE WHEN $5 = 'created_at' AND $6 = 'desc' THEN tools.created_at END DESC NULLS LAST, + CASE WHEN $5 = 'created_at' AND $6 = 'asc' THEN tools.created_at END ASC NULLS LAST, + CASE WHEN $5 = 'updated_at' AND $6 = 'desc' THEN tools.updated_at END DESC NULLS LAST, + CASE WHEN $5 = 'updated_at' AND $6 = 'asc' THEN tools.updated_at END ASC NULLS LAST +LIMIT $3 OFFSET $4; +""" + +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("list_tools") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + Tool, + transform=lambda d: { + d["type"]: { + **d.pop("spec"), + "name": d["name"], + "description": d["description"], + }, + "id": d.pop("tool_id"), + **d, + }, +) +@pg_query +@beartype +async def list_tools( + *, + developer_id: UUID, + agent_id: UUID, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", +) -> tuple[str, list]: + developer_id = str(developer_id) + agent_id = str(agent_id) + + return ( + sql_query, + [ + developer_id, + agent_id, + limit, + offset, + sort_by, + direction, + ], + ) diff --git a/agents-api/agents_api/models/tools/patch_tool.py b/agents-api/agents_api/queries/tools/patch_tool.py similarity index 53% rename from agents-api/agents_api/models/tools/patch_tool.py rename to agents-api/agents_api/queries/tools/patch_tool.py index bc49b8121..e0a20dc1d 100644 --- a/agents-api/agents_api/models/tools/patch_tool.py +++ b/agents-api/agents_api/queries/tools/patch_tool.py @@ -1,20 +1,14 @@ from typing import Any, TypeVar from uuid import UUID +import sqlvalidator from beartype import beartype -from fastapi import HTTPException -from pycozo.client import QueryException -from pydantic import ValidationError from ...autogen.openapi_model import PatchToolRequest, ResourceUpdatedResponse -from ...common.utils.cozo import cozo_process_mutate_data +from ...exceptions import InvalidSQLQuery from ...metrics.counters import increase_counter from ..utils import ( - cozo_query, - partialclass, - rewrap_exceptions, - verify_developer_id_query, - verify_developer_owns_resource_query, + pg_query, wrap_in_class, ) @@ -22,25 +16,45 @@ T = TypeVar("T") -@rewrap_exceptions( - { - QueryException: partialclass(HTTPException, status_code=400), - ValidationError: partialclass(HTTPException, status_code=400), - TypeError: partialclass(HTTPException, status_code=400), - } +sql_query = """ +WITH updated_tools AS ( + UPDATE tools + SET + type = COALESCE($4, type), + name = COALESCE($5, name), + description = COALESCE($6, description), + spec = COALESCE($7, spec) + WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 + RETURNING * ) +SELECT * FROM updated_tools; +""" + +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("patch_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) @wrap_in_class( ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, - _kind="inserted", ) -@cozo_query +@pg_query @increase_counter("patch_tool") @beartype -def patch_tool( +async def patch_tool( *, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest -) -> tuple[list[str], dict]: +) -> tuple[str, list]: """ Execute the datalog query and return the results as a DataFrame Updates the tool information for a given agent and tool ID in the 'cozodb' database. @@ -54,6 +68,7 @@ def patch_tool( ResourceUpdatedResponse: The updated tool data. """ + developer_id = str(developer_id) agent_id = str(agent_id) tool_id = str(tool_id) @@ -78,39 +93,15 @@ def patch_tool( if tool_spec: del patch_data[tool_type] - tool_cols, tool_vals = cozo_process_mutate_data( - { - **patch_data, - "agent_id": agent_id, - "tool_id": tool_id, - } - ) - - # Construct the datalog query for updating the tool information - patch_query = f""" - input[{tool_cols}] <- $input - - ?[{tool_cols}, spec, updated_at] := - *tools {{ - agent_id: to_uuid($agent_id), - tool_id: to_uuid($tool_id), - spec: old_spec, - }}, - input[{tool_cols}], - spec = concat(old_spec, $spec), - updated_at = now() - - :update tools {{ {tool_cols}, spec, updated_at }} - :returning - """ - - queries = [ - verify_developer_id_query(developer_id), - verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id), - patch_query, - ] - return ( - queries, - dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id), + sql_query, + [ + developer_id, + agent_id, + tool_id, + tool_type, + data.name, + data.description, + tool_spec, + ], ) diff --git a/agents-api/agents_api/queries/tools/update_tool.py b/agents-api/agents_api/queries/tools/update_tool.py new file mode 100644 index 000000000..2b8beb155 --- /dev/null +++ b/agents-api/agents_api/queries/tools/update_tool.py @@ -0,0 +1,97 @@ +from typing import Any, TypeVar +from uuid import UUID + +import sqlvalidator +from beartype import beartype + +from ...autogen.openapi_model import ( + ResourceUpdatedResponse, + UpdateToolRequest, +) +from ...exceptions import InvalidSQLQuery +from ...metrics.counters import increase_counter +from ..utils import ( + pg_query, + wrap_in_class, +) + +ModelT = TypeVar("ModelT", bound=Any) +T = TypeVar("T") + +sql_query = """ +UPDATE tools +SET + type = $4, + name = $5, + description = $6, + spec = $7 +WHERE + developer_id = $1 AND + agent_id = $2 AND + tool_id = $3 +RETURNING *; +""" + +# if not sql_query.is_valid(): +# raise InvalidSQLQuery("update_tool") + + +# @rewrap_exceptions( +# { +# QueryException: partialclass(HTTPException, status_code=400), +# ValidationError: partialclass(HTTPException, status_code=400), +# TypeError: partialclass(HTTPException, status_code=400), +# } +# ) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: {"id": d["tool_id"], "jobs": [], **d}, +) +@pg_query +@increase_counter("update_tool") +@beartype +async def update_tool( + *, + developer_id: UUID, + agent_id: UUID, + tool_id: UUID, + data: UpdateToolRequest, + **kwargs, +) -> tuple[str, list]: + developer_id = str(developer_id) + agent_id = str(agent_id) + tool_id = str(tool_id) + + # Extract the tool data from the payload + update_data = data.model_dump(exclude_none=True) + + # Assert that only one of the tool type fields is present + tool_specs = [ + (tool_type, update_data.get(tool_type)) + for tool_type in ["function", "integration", "system", "api_call"] + if update_data.get(tool_type) is not None + ] + + assert len(tool_specs) <= 1, "Invalid tool update" + tool_type, tool_spec = tool_specs[0] if tool_specs else (None, None) + + if tool_type is not None: + update_data["type"] = update_data.get("type", tool_type) + assert update_data["type"] == tool_type, "Invalid tool update" + + update_data["spec"] = tool_spec + del update_data[tool_type] + + return ( + sql_query, + [ + developer_id, + agent_id, + tool_id, + tool_type, + data.name, + data.description, + tool_spec, + ], + ) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index ea3866ff2..a98fef531 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -12,30 +12,28 @@ CreateFileRequest, CreateSessionRequest, CreateTaskRequest, + CreateToolRequest, CreateUserRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode from agents_api.queries.agents.create_agent import create_agent +from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.create_developer import create_developer - -# from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer from agents_api.queries.docs.create_doc import create_doc +from agents_api.queries.docs.delete_doc import delete_doc -# from agents_api.queries.docs.delete_doc import delete_doc -# from agents_api.queries.execution.create_execution import create_execution -# from agents_api.queries.execution.create_execution_transition import create_execution_transition -# from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup +# from agents_api.queries.executions.create_execution import create_execution +# from agents_api.queries.executions.create_execution_transition import create_execution_transition +# from agents_api.queries.executions.create_temporal_lookup import create_temporal_lookup from agents_api.queries.files.create_file import create_file - -# from agents_api.queries.files.delete_file import delete_file +from agents_api.queries.files.delete_file import delete_file from agents_api.queries.sessions.create_session import create_session from agents_api.queries.tasks.create_task import create_task - -# from agents_api.queries.task.delete_task import delete_task -# from agents_api.queries.tools.create_tools import create_tools -# from agents_api.queries.tools.delete_tool import delete_tool +from agents_api.queries.tasks.delete_task import delete_task +from agents_api.queries.tools.create_tools import create_tools +from agents_api.queries.tools.delete_tool import delete_tool from agents_api.queries.users.create_user import create_user from agents_api.web import app @@ -347,40 +345,41 @@ async def test_session( # yield transition -# @fixture(scope="global") -# async def test_tool( -# dsn=pg_dsn, -# developer_id=test_developer_id, -# agent=test_agent, -# ): -# function = { -# "description": "A function that prints hello world", -# "parameters": {"type": "object", "properties": {}}, -# } - -# tool = { -# "function": function, -# "name": "hello_world1", -# "type": "function", -# } - -# [tool, *_] = await create_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# connection_pool=pool, -# ) -# yield tool - -# # Cleanup -# try: -# await delete_tool( -# developer_id=developer_id, -# tool_id=tool.id, -# connection_pool=pool, -# ) -# finally: -# await pool.close() +@fixture(scope="global") +async def test_tool( + dsn=pg_dsn, + developer_id=test_developer_id, + agent=test_agent, +): + pool = await create_db_pool(dsn=dsn) + function = { + "description": "A function that prints hello world", + "parameters": {"type": "object", "properties": {}}, + } + + tool = { + "function": function, + "name": "hello_world1", + "type": "function", + } + + [tool, *_] = await create_tools( + developer_id=developer_id, + agent_id=agent.id, + data=[CreateToolRequest(**tool)], + connection_pool=pool, + ) + yield tool + + # Cleanup + try: + await delete_tool( + developer_id=developer_id, + tool_id=tool.id, + connection_pool=pool, + ) + finally: + await pool.close() @fixture(scope="global") diff --git a/agents-api/tests/test_tool_queries.py b/agents-api/tests/test_tool_queries.py index f6f4bac47..5056f03ca 100644 --- a/agents-api/tests/test_tool_queries.py +++ b/agents-api/tests/test_tool_queries.py @@ -1,170 +1,178 @@ # # Tests for tool queries -# from ward import test - -# from agents_api.autogen.openapi_model import ( -# CreateToolRequest, -# PatchToolRequest, -# Tool, -# UpdateToolRequest, -# ) -# from agents_api.queries.tools.create_tools import create_tools -# from agents_api.queries.tools.delete_tool import delete_tool -# from agents_api.queries.tools.get_tool import get_tool -# from agents_api.queries.tools.list_tools import list_tools -# from agents_api.queries.tools.patch_tool import patch_tool -# from agents_api.queries.tools.update_tool import update_tool -# from tests.fixtures import cozo_client, test_agent, test_developer_id, test_tool - - -# @test("query: create tool") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# function = { -# "name": "hello_world", -# "description": "A function that prints hello world", -# "parameters": {"type": "object", "properties": {}}, -# } - -# tool = { -# "function": function, -# "name": "hello_world", -# "type": "function", -# } - -# result = create_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# client=client, -# ) - -# assert result is not None -# assert isinstance(result[0], Tool) - - -# @test("query: delete tool") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# function = { -# "name": "temp_temp", -# "description": "A function that prints hello world", -# "parameters": {"type": "object", "properties": {}}, -# } - -# tool = { -# "function": function, -# "name": "temp_temp", -# "type": "function", -# } - -# [tool, *_] = create_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# client=client, -# ) - -# result = delete_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert result is not None - - -# @test("query: get tool") -# def _( -# client=cozo_client, developer_id=test_developer_id, tool=test_tool, agent=test_agent -# ): -# result = get_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert result is not None - - -# @test("query: list tools") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -# ): -# result = list_tools( -# developer_id=developer_id, -# agent_id=agent.id, -# client=client, -# ) - -# assert result is not None -# assert all(isinstance(tool, Tool) for tool in result) - - -# @test("query: patch tool") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -# ): -# patch_data = PatchToolRequest( -# **{ -# "name": "patched_tool", -# "function": { -# "description": "A patched function that prints hello world", -# }, -# } -# ) - -# result = patch_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# data=patch_data, -# client=client, -# ) - -# assert result is not None - -# tool = get_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert tool.name == "patched_tool" -# assert tool.function.description == "A patched function that prints hello world" -# assert tool.function.parameters - - -# @test("query: update tool") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, tool=test_tool -# ): -# update_data = UpdateToolRequest( -# name="updated_tool", -# description="An updated description", -# type="function", -# function={ -# "description": "An updated function that prints hello world", -# }, -# ) - -# result = update_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# data=update_data, -# client=client, -# ) - -# assert result is not None - -# tool = get_tool( -# developer_id=developer_id, -# agent_id=agent.id, -# tool_id=tool.id, -# client=client, -# ) - -# assert tool.name == "updated_tool" -# assert not tool.function.parameters +from ward import test + +from agents_api.autogen.openapi_model import ( + CreateToolRequest, + PatchToolRequest, + Tool, + UpdateToolRequest, +) +from agents_api.clients.pg import create_db_pool +from agents_api.queries.tools.create_tools import create_tools +from agents_api.queries.tools.delete_tool import delete_tool +from agents_api.queries.tools.get_tool import get_tool +from agents_api.queries.tools.list_tools import list_tools +from agents_api.queries.tools.patch_tool import patch_tool +from agents_api.queries.tools.update_tool import update_tool +from tests.fixtures import pg_dsn, test_agent, test_developer_id, test_tool + + +@test("query: create tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + function = { + "name": "hello_world", + "description": "A function that prints hello world", + "parameters": {"type": "object", "properties": {}}, + } + + tool = { + "function": function, + "name": "hello_world", + "type": "function", + } + + result = await create_tools( + developer_id=developer_id, + agent_id=agent.id, + data=[CreateToolRequest(**tool)], + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result[0], Tool) + + +@test("query: delete tool") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + function = { + "name": "temp_temp", + "description": "A function that prints hello world", + "parameters": {"type": "object", "properties": {}}, + } + + tool = { + "function": function, + "name": "temp_temp", + "type": "function", + } + + [tool, *_] = await create_tools( + developer_id=developer_id, + agent_id=agent.id, + data=[CreateToolRequest(**tool)], + connection_pool=pool, + ) + + result = await delete_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert result is not None + + +@test("query: get tool") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, tool=test_tool, agent=test_agent +): + pool = await create_db_pool(dsn=dsn) + result = get_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert result is not None + + +@test("query: list tools") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool +): + pool = await create_db_pool(dsn=dsn) + result = await list_tools( + developer_id=developer_id, + agent_id=agent.id, + connection_pool=pool, + ) + + assert result is not None + assert all(isinstance(tool, Tool) for tool in result) + + +@test("query: patch tool") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool +): + pool = await create_db_pool(dsn=dsn) + patch_data = PatchToolRequest( + **{ + "name": "patched_tool", + "function": { + "description": "A patched function that prints hello world", + "parameters": {"param1": "value1"}, + }, + } + ) + + result = await patch_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + data=patch_data, + connection_pool=pool, + ) + + assert result is not None + + tool = await get_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert tool.name == "patched_tool" + assert tool.function.description == "A patched function that prints hello world" + assert tool.function.parameters + + +@test("query: update tool") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, tool=test_tool +): + pool = await create_db_pool(dsn=dsn) + update_data = UpdateToolRequest( + name="updated_tool", + description="An updated description", + type="function", + function={ + "description": "An updated function that prints hello world", + }, + ) + + result = await update_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + data=update_data, + connection_pool=pool, + ) + + assert result is not None + + tool = await get_tool( + developer_id=developer_id, + agent_id=agent.id, + tool_id=tool.id, + connection_pool=pool, + ) + + assert tool.name == "updated_tool" + assert not tool.function.parameters