From 3a627b185d7ed30cf81cf33af1a3f76f7e67d2c1 Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Mon, 16 Dec 2024 18:31:44 -0500 Subject: [PATCH 01/29] feat(agents-api): Add entry queries --- .../queries/{ => entry}/__init__.py | 0 .../queries/entry/create_entries.py | 101 ++++++++++++++++++ .../queries/entry/delete_entries.py | 43 ++++++++ .../agents_api/queries/entry/get_history.py | 71 ++++++++++++ .../agents_api/queries/entry/list_entries.py | 74 +++++++++++++ 5 files changed, 289 insertions(+) rename agents-api/agents_api/queries/{ => entry}/__init__.py (100%) create mode 100644 agents-api/agents_api/queries/entry/create_entries.py create mode 100644 agents-api/agents_api/queries/entry/delete_entries.py create mode 100644 agents-api/agents_api/queries/entry/get_history.py create mode 100644 agents-api/agents_api/queries/entry/list_entries.py diff --git a/agents-api/agents_api/queries/__init__.py b/agents-api/agents_api/queries/entry/__init__.py similarity index 100% rename from agents-api/agents_api/queries/__init__.py rename to agents-api/agents_api/queries/entry/__init__.py diff --git a/agents-api/agents_api/queries/entry/create_entries.py b/agents-api/agents_api/queries/entry/create_entries.py new file mode 100644 index 000000000..feeebde89 --- /dev/null +++ b/agents-api/agents_api/queries/entry/create_entries.py @@ -0,0 +1,101 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize + +from ...autogen.openapi_model import CreateEntryRequest, Entry +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ...common.utils.datetime import utcnow +from ...common.utils.messages import content_to_json +from uuid_extensions import uuid7 + +# Define the raw SQL query for creating entries +raw_query = """ +INSERT INTO entries ( + session_id, + entry_id, + source, + role, + event_type, + name, + content, + tool_call_id, + tool_calls, + model, + token_count, + created_at, + timestamp +) +VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 +) +RETURNING *; +""" + +# Parse and optimize the query +query = optimize( + parse_one(raw_query), + schema={ + "entries": { + "session_id": "UUID", + "entry_id": "UUID", + "source": "TEXT", + "role": "chat_role", + "event_type": "TEXT", + "name": "TEXT", + "content": "JSONB[]", + "tool_call_id": "TEXT", + "tool_calls": "JSONB[]", + "model": "TEXT", + "token_count": "INTEGER", + "created_at": "TIMESTAMP", + "timestamp": "TIMESTAMP", + } + }, +).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), + asyncpg.UniqueViolationError: partialclass(HTTPException, status_code=409), + } +) +@wrap_in_class(Entry) +@increase_counter("create_entries") +@pg_query +@beartype +def create_entries( + *, + developer_id: UUID, + session_id: UUID, + data: list[CreateEntryRequest], + mark_session_as_updated: bool = True, +) -> tuple[str, list]: + + data_dicts = [item.model_dump(mode="json") for item in data] + + params = [ + ( + session_id, + item.pop("id", None) or str(uuid7()), + item.get("source"), + item.get("role"), + item.get("event_type") or 'message.create', + item.get("name"), + content_to_json(item.get("content") or []), + item.get("tool_call_id"), + item.get("tool_calls") or [], + item.get("model"), + item.get("token_count"), + (item.get("created_at") or utcnow()).timestamp(), + utcnow().timestamp(), + ) + for item in data_dicts + ] + + return query, params diff --git a/agents-api/agents_api/queries/entry/delete_entries.py b/agents-api/agents_api/queries/entry/delete_entries.py new file mode 100644 index 000000000..0150be3ee --- /dev/null +++ b/agents-api/agents_api/queries/entry/delete_entries.py @@ -0,0 +1,43 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for deleting entries +raw_query = """ +DELETE FROM entries +WHERE session_id = $1 +RETURNING session_id as id; +""" + +# Parse and optimize the query +query = optimize( + parse_one(raw_query), + schema={ + "entries": { + "session_id": "UUID", + } + }, +).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), + } +) +@wrap_in_class(ResourceDeletedResponse, one=True) +@increase_counter("delete_entries_for_session") +@pg_query +@beartype +def delete_entries_for_session( + *, developer_id: UUID, session_id: UUID, mark_session_as_updated: bool = True +) -> tuple[str, dict]: + return query, [session_id] diff --git a/agents-api/agents_api/queries/entry/get_history.py b/agents-api/agents_api/queries/entry/get_history.py new file mode 100644 index 000000000..eae4f4e6c --- /dev/null +++ b/agents-api/agents_api/queries/entry/get_history.py @@ -0,0 +1,71 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize + +from ...autogen.openapi_model import History +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for getting history +raw_query = """ +SELECT + e.entry_id as id, + e.session_id, + e.role, + e.name, + e.content, + e.source, + e.token_count, + e.tokenizer, + e.created_at, + e.timestamp, + e.tool_calls, + e.tool_call_id +FROM entries e +WHERE e.session_id = $1 +AND e.source = ANY($2) +ORDER BY e.created_at; +""" + +# Parse and optimize the query +query = optimize( + parse_one(raw_query), + schema={ + "entries": { + "entry_id": "UUID", + "session_id": "UUID", + "role": "STRING", + "name": "STRING", + "content": "JSONB", + "source": "STRING", + "token_count": "INTEGER", + "tokenizer": "STRING", + "created_at": "TIMESTAMP", + "timestamp": "TIMESTAMP", + "tool_calls": "JSONB", + "tool_call_id": "UUID", + } + }, +).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), + } +) +@wrap_in_class(History, one=True) +@increase_counter("get_history") +@pg_query +@beartype +def get_history( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], +) -> tuple[str, list]: + return query, [session_id, allowed_sources] diff --git a/agents-api/agents_api/queries/entry/list_entries.py b/agents-api/agents_api/queries/entry/list_entries.py new file mode 100644 index 000000000..e5884b1b3 --- /dev/null +++ b/agents-api/agents_api/queries/entry/list_entries.py @@ -0,0 +1,74 @@ +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize + +from ...autogen.openapi_model import Entry +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for listing entries +raw_query = """ +SELECT + e.entry_id as id, + e.session_id, + e.role, + e.name, + e.content, + e.source, + e.token_count, + e.tokenizer, + e.created_at, + e.timestamp +FROM entries e +WHERE e.session_id = $1 +AND e.source = ANY($2) +ORDER BY e.$3 $4 +LIMIT $5 OFFSET $6; +""" + +# Parse and optimize the query +query = optimize( + parse_one(raw_query), + schema={ + "entries": { + "entry_id": "UUID", + "session_id": "UUID", + "role": "STRING", + "name": "STRING", + "content": "JSONB", + "source": "STRING", + "token_count": "INTEGER", + "tokenizer": "STRING", + "created_at": "TIMESTAMP", + "timestamp": "TIMESTAMP", + } + }, +).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), + } +) +@wrap_in_class(Entry) +@increase_counter("list_entries") +@pg_query +@beartype +def list_entries( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], + limit: int = -1, + offset: int = 0, + sort_by: Literal["created_at", "timestamp"] = "timestamp", + direction: Literal["asc", "desc"] = "asc", + exclude_relations: list[str] = [], +) -> tuple[str, dict]: + return query, [session_id, allowed_sources, sort_by, direction, limit, offset] From 6aa48071eaa6dc7847915e9d1b0b8e3ba08f7ec2 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Mon, 16 Dec 2024 23:32:45 +0000 Subject: [PATCH 02/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/entry/create_entries.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/agents-api/agents_api/queries/entry/create_entries.py b/agents-api/agents_api/queries/entry/create_entries.py index feeebde89..98bac13c6 100644 --- a/agents-api/agents_api/queries/entry/create_entries.py +++ b/agents-api/agents_api/queries/entry/create_entries.py @@ -5,13 +5,13 @@ from fastapi import HTTPException from sqlglot import parse_one from sqlglot.optimizer import optimize +from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateEntryRequest, Entry -from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class from ...common.utils.datetime import utcnow from ...common.utils.messages import content_to_json -from uuid_extensions import uuid7 +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for creating entries raw_query = """ @@ -76,16 +76,15 @@ def create_entries( data: list[CreateEntryRequest], mark_session_as_updated: bool = True, ) -> tuple[str, list]: - data_dicts = [item.model_dump(mode="json") for item in data] - + params = [ ( session_id, item.pop("id", None) or str(uuid7()), item.get("source"), item.get("role"), - item.get("event_type") or 'message.create', + item.get("event_type") or "message.create", item.get("name"), content_to_json(item.get("content") or []), item.get("tool_call_id"), From a8d20686d83be37ac52e8718e7d175499a8f8e39 Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Mon, 16 Dec 2024 23:08:02 -0500 Subject: [PATCH 03/29] chore: update the entyr queries --- .../agents_api/queries/entry/__init__.py | 21 +++++++++++++++++++ .../queries/entry/create_entries.py | 5 ++++- .../queries/entry/delete_entries.py | 5 ++++- .../agents_api/queries/entry/get_history.py | 9 ++++---- .../agents_api/queries/entry/list_entries.py | 8 ++++--- 5 files changed, 39 insertions(+), 9 deletions(-) diff --git a/agents-api/agents_api/queries/entry/__init__.py b/agents-api/agents_api/queries/entry/__init__.py index e69de29bb..2ad83f115 100644 --- a/agents-api/agents_api/queries/entry/__init__.py +++ b/agents-api/agents_api/queries/entry/__init__.py @@ -0,0 +1,21 @@ +""" +The `entry` module provides SQL query functions for managing entries +in the TimescaleDB database. This includes operations for: + +- Creating new entries +- Deleting entries +- Retrieving entry history +- Listing entries with filtering and pagination +""" + +from .create_entries import create_entries +from .delete_entries import delete_entries_for_session +from .get_history import get_history +from .list_entries import list_entries + +__all__ = [ + "create_entries", + "delete_entries_for_session", + "get_history", + "list_entries", +] diff --git a/agents-api/agents_api/queries/entry/create_entries.py b/agents-api/agents_api/queries/entry/create_entries.py index 98bac13c6..3edad7b42 100644 --- a/agents-api/agents_api/queries/entry/create_entries.py +++ b/agents-api/agents_api/queries/entry/create_entries.py @@ -97,4 +97,7 @@ def create_entries( for item in data_dicts ] - return query, params + return ( + query, + params, + ) diff --git a/agents-api/agents_api/queries/entry/delete_entries.py b/agents-api/agents_api/queries/entry/delete_entries.py index 0150be3ee..d19dfa632 100644 --- a/agents-api/agents_api/queries/entry/delete_entries.py +++ b/agents-api/agents_api/queries/entry/delete_entries.py @@ -40,4 +40,7 @@ def delete_entries_for_session( *, developer_id: UUID, session_id: UUID, mark_session_as_updated: bool = True ) -> tuple[str, dict]: - return query, [session_id] + return ( + query, + [session_id], + ) diff --git a/agents-api/agents_api/queries/entry/get_history.py b/agents-api/agents_api/queries/entry/get_history.py index eae4f4e6c..8b98ed25c 100644 --- a/agents-api/agents_api/queries/entry/get_history.py +++ b/agents-api/agents_api/queries/entry/get_history.py @@ -20,7 +20,6 @@ e.content, e.source, e.token_count, - e.tokenizer, e.created_at, e.timestamp, e.tool_calls, @@ -43,7 +42,6 @@ "content": "JSONB", "source": "STRING", "token_count": "INTEGER", - "tokenizer": "STRING", "created_at": "TIMESTAMP", "timestamp": "TIMESTAMP", "tool_calls": "JSONB", @@ -67,5 +65,8 @@ def get_history( developer_id: UUID, session_id: UUID, allowed_sources: list[str] = ["api_request", "api_response"], -) -> tuple[str, list]: - return query, [session_id, allowed_sources] +) -> tuple[str, dict]: + return ( + query, + [session_id, allowed_sources], + ) diff --git a/agents-api/agents_api/queries/entry/list_entries.py b/agents-api/agents_api/queries/entry/list_entries.py index e5884b1b3..d2b664866 100644 --- a/agents-api/agents_api/queries/entry/list_entries.py +++ b/agents-api/agents_api/queries/entry/list_entries.py @@ -21,7 +21,6 @@ e.content, e.source, e.token_count, - e.tokenizer, e.created_at, e.timestamp FROM entries e @@ -43,7 +42,6 @@ "content": "JSONB", "source": "STRING", "token_count": "INTEGER", - "tokenizer": "STRING", "created_at": "TIMESTAMP", "timestamp": "TIMESTAMP", } @@ -71,4 +69,8 @@ def list_entries( direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], ) -> tuple[str, dict]: - return query, [session_id, allowed_sources, sort_by, direction, limit, offset] + + return ( + query, + [session_id, allowed_sources, sort_by, direction, limit, offset], + ) From dc2002f199564153aa4688a0aca43ead110115c0 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Tue, 17 Dec 2024 04:09:08 +0000 Subject: [PATCH 04/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/entry/list_entries.py | 1 - 1 file changed, 1 deletion(-) diff --git a/agents-api/agents_api/queries/entry/list_entries.py b/agents-api/agents_api/queries/entry/list_entries.py index d2b664866..6d8d88de5 100644 --- a/agents-api/agents_api/queries/entry/list_entries.py +++ b/agents-api/agents_api/queries/entry/list_entries.py @@ -69,7 +69,6 @@ def list_entries( direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], ) -> tuple[str, dict]: - return ( query, [session_id, allowed_sources, sort_by, direction, limit, offset], From 70b759848b48b6f27ff99a7dbf696e33be073eeb Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Mon, 16 Dec 2024 23:29:49 -0500 Subject: [PATCH 05/29] chore: inner join developer table with entry queries --- .../agents_api/queries/entry/create_entries.py | 10 +++++++--- .../agents_api/queries/entry/delete_entries.py | 12 +++++++----- agents-api/agents_api/queries/entry/get_history.py | 7 ++++--- agents-api/agents_api/queries/entry/list_entries.py | 7 ++++--- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/agents-api/agents_api/queries/entry/create_entries.py b/agents-api/agents_api/queries/entry/create_entries.py index 3edad7b42..c131b0362 100644 --- a/agents-api/agents_api/queries/entry/create_entries.py +++ b/agents-api/agents_api/queries/entry/create_entries.py @@ -13,7 +13,7 @@ from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Define the raw SQL query for creating entries +# Define the raw SQL query for creating entries with a developer check raw_query = """ INSERT INTO entries ( session_id, @@ -30,9 +30,12 @@ created_at, timestamp ) -VALUES ( +SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 -) +FROM + developers +WHERE + developer_id = $14 RETURNING *; """ @@ -93,6 +96,7 @@ def create_entries( item.get("token_count"), (item.get("created_at") or utcnow()).timestamp(), utcnow().timestamp(), + developer_id ) for item in data_dicts ] diff --git a/agents-api/agents_api/queries/entry/delete_entries.py b/agents-api/agents_api/queries/entry/delete_entries.py index d19dfa632..1fa34176f 100644 --- a/agents-api/agents_api/queries/entry/delete_entries.py +++ b/agents-api/agents_api/queries/entry/delete_entries.py @@ -10,11 +10,13 @@ from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Define the raw SQL query for deleting entries +# Define the raw SQL query for deleting entries with a developer check raw_query = """ DELETE FROM entries -WHERE session_id = $1 -RETURNING session_id as id; +USING developers +WHERE entries.session_id = $1 +AND developers.developer_id = $2 +RETURNING entries.session_id as id; """ # Parse and optimize the query @@ -39,8 +41,8 @@ @beartype def delete_entries_for_session( *, developer_id: UUID, session_id: UUID, mark_session_as_updated: bool = True -) -> tuple[str, dict]: +) -> tuple[str, list]: return ( query, - [session_id], + [session_id, developer_id], ) diff --git a/agents-api/agents_api/queries/entry/get_history.py b/agents-api/agents_api/queries/entry/get_history.py index 8b98ed25c..dd06734b0 100644 --- a/agents-api/agents_api/queries/entry/get_history.py +++ b/agents-api/agents_api/queries/entry/get_history.py @@ -10,7 +10,7 @@ from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Define the raw SQL query for getting history +# Define the raw SQL query for getting history with a developer check raw_query = """ SELECT e.entry_id as id, @@ -25,6 +25,7 @@ e.tool_calls, e.tool_call_id FROM entries e +JOIN developers d ON d.developer_id = $3 WHERE e.session_id = $1 AND e.source = ANY($2) ORDER BY e.created_at; @@ -65,8 +66,8 @@ def get_history( developer_id: UUID, session_id: UUID, allowed_sources: list[str] = ["api_request", "api_response"], -) -> tuple[str, dict]: +) -> tuple[str, list]: return ( query, - [session_id, allowed_sources], + [session_id, allowed_sources, developer_id], ) diff --git a/agents-api/agents_api/queries/entry/list_entries.py b/agents-api/agents_api/queries/entry/list_entries.py index 6d8d88de5..42add6899 100644 --- a/agents-api/agents_api/queries/entry/list_entries.py +++ b/agents-api/agents_api/queries/entry/list_entries.py @@ -11,7 +11,7 @@ from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Define the raw SQL query for listing entries +# Define the raw SQL query for listing entries with a developer check raw_query = """ SELECT e.entry_id as id, @@ -24,6 +24,7 @@ e.created_at, e.timestamp FROM entries e +JOIN developers d ON d.developer_id = $7 WHERE e.session_id = $1 AND e.source = ANY($2) ORDER BY e.$3 $4 @@ -68,8 +69,8 @@ def list_entries( sort_by: Literal["created_at", "timestamp"] = "timestamp", direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], -) -> tuple[str, dict]: +) -> tuple[str, list]: return ( query, - [session_id, allowed_sources, sort_by, direction, limit, offset], + [session_id, allowed_sources, sort_by, direction, limit, offset, developer_id], ) From 5cf876757d3a8b583775aec2482c6928b647d314 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Tue, 17 Dec 2024 04:30:39 +0000 Subject: [PATCH 06/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/entry/create_entries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agents-api/agents_api/queries/entry/create_entries.py b/agents-api/agents_api/queries/entry/create_entries.py index c131b0362..d3b3b4982 100644 --- a/agents-api/agents_api/queries/entry/create_entries.py +++ b/agents-api/agents_api/queries/entry/create_entries.py @@ -96,7 +96,7 @@ def create_entries( item.get("token_count"), (item.get("created_at") or utcnow()).timestamp(), utcnow().timestamp(), - developer_id + developer_id, ) for item in data_dicts ] From 8b6b0d90062fc1dc7471c4cd6239ca4cfded5275 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Tue, 17 Dec 2024 23:59:05 +0530 Subject: [PATCH 07/29] wip(agents-api): Add session sql queries Signed-off-by: Diwank Singh Tomer --- .../agents_api/queries/sessions/__init__.py | 31 ++ .../queries/sessions/count_sessions.py | 55 ++++ .../sessions/create_or_update_session.py | 151 ++++++++++ .../queries/sessions/create_session.py | 138 +++++++++ .../queries/sessions/delete_session.py | 69 +++++ .../queries/sessions/get_session.py | 85 ++++++ .../queries/sessions/list_sessions.py | 109 +++++++ .../queries/sessions/patch_session.py | 131 +++++++++ .../queries/sessions/update_session.py | 131 +++++++++ .../agents_api/queries/users/list_users.py | 3 - agents-api/tests/test_session_queries.py | 265 ++++++++++-------- .../migrations/000009_sessions.up.sql | 3 +- memory-store/migrations/000015_entries.up.sql | 16 ++ 13 files changed, 1065 insertions(+), 122 deletions(-) create mode 100644 agents-api/agents_api/queries/sessions/__init__.py create mode 100644 agents-api/agents_api/queries/sessions/count_sessions.py create mode 100644 agents-api/agents_api/queries/sessions/create_or_update_session.py create mode 100644 agents-api/agents_api/queries/sessions/create_session.py create mode 100644 agents-api/agents_api/queries/sessions/delete_session.py create mode 100644 agents-api/agents_api/queries/sessions/get_session.py create mode 100644 agents-api/agents_api/queries/sessions/list_sessions.py create mode 100644 agents-api/agents_api/queries/sessions/patch_session.py create mode 100644 agents-api/agents_api/queries/sessions/update_session.py diff --git a/agents-api/agents_api/queries/sessions/__init__.py b/agents-api/agents_api/queries/sessions/__init__.py new file mode 100644 index 000000000..bf192210b --- /dev/null +++ b/agents-api/agents_api/queries/sessions/__init__.py @@ -0,0 +1,31 @@ +""" +The `sessions` module within the `queries` package provides SQL query functions for managing sessions +in the PostgreSQL database. This includes operations for: + +- Creating new sessions +- Updating existing sessions +- Retrieving session details +- Listing sessions with filtering and pagination +- Deleting sessions +""" + +from .count_sessions import count_sessions +from .create_or_update_session import create_or_update_session +from .create_session import create_session +from .delete_session import delete_session +from .get_session import get_session +from .list_sessions import list_sessions +from .patch_session import patch_session +from .update_session import update_session + +__all__ = [ + "count_sessions", + "create_or_update_session", + "create_session", + "delete_session", + "get_session", + "list_sessions", + "patch_session", + "update_session", +] + diff --git a/agents-api/agents_api/queries/sessions/count_sessions.py b/agents-api/agents_api/queries/sessions/count_sessions.py new file mode 100644 index 000000000..71c1ec0dc --- /dev/null +++ b/agents-api/agents_api/queries/sessions/count_sessions.py @@ -0,0 +1,55 @@ +"""This module contains functions for querying session data from the PostgreSQL database.""" + +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query outside the function +raw_query = """ +SELECT COUNT(session_id) as count +FROM sessions +WHERE developer_id = $1; +""" + +# Parse and optimize the query +query = parse_one(raw_query).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) +@wrap_in_class(dict, one=True) +@increase_counter("count_sessions") +@pg_query +@beartype +async def count_sessions( + *, + developer_id: UUID, +) -> tuple[str, list]: + """ + Counts sessions from the PostgreSQL database. + Uses the index on developer_id for efficient counting. + + Args: + developer_id (UUID): The developer's ID to filter sessions by. + + Returns: + tuple[str, list]: SQL query and parameters. + """ + + return ( + query, + [developer_id], + ) \ No newline at end of file diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py new file mode 100644 index 000000000..4bbbef091 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -0,0 +1,151 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import CreateOrUpdateSessionRequest, ResourceUpdatedResponse +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +session_query = parse_one(""" +INSERT INTO sessions ( + developer_id, + session_id, + situation, + system_template, + metadata, + render_templates, + token_budget, + context_overflow, + forward_tool_calls, + recall_options +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9, + $10 +) +ON CONFLICT (developer_id, session_id) DO UPDATE SET + situation = EXCLUDED.situation, + system_template = EXCLUDED.system_template, + metadata = EXCLUDED.metadata, + render_templates = EXCLUDED.render_templates, + token_budget = EXCLUDED.token_budget, + context_overflow = EXCLUDED.context_overflow, + forward_tool_calls = EXCLUDED.forward_tool_calls, + recall_options = EXCLUDED.recall_options +RETURNING *; +""").sql(pretty=True) + +lookup_query = parse_one(""" +WITH deleted_lookups AS ( + DELETE FROM session_lookup + WHERE developer_id = $1 AND session_id = $2 +) +INSERT INTO session_lookup ( + developer_id, + session_id, + participant_type, + participant_id +) +SELECT + $1 as developer_id, + $2 as session_id, + unnest($3::participant_type[]) as participant_type, + unnest($4::uuid[]) as participant_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or participant does not exist.", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A session with this ID already exists.", + ), + } +) +@wrap_in_class(ResourceUpdatedResponse, one=True) +@increase_counter("create_or_update_session") +@pg_query +@beartype +async def create_or_update_session( + *, + developer_id: UUID, + session_id: UUID, + data: CreateOrUpdateSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to create or update a session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (CreateOrUpdateSessionRequest): Session data to insert or update + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + # Handle participants + users = data.users or ([data.user] if data.user else []) + agents = data.agents or ([data.agent] if data.agent else []) + + if not agents: + raise HTTPException( + status_code=400, + detail="At least one agent must be provided", + ) + + if data.agent and data.agents: + raise HTTPException( + status_code=400, + detail="Only one of 'agent' or 'agents' should be provided", + ) + + # Prepare participant arrays for lookup query + participant_types = ( + ["user"] * len(users) + ["agent"] * len(agents) + ) + participant_ids = [str(u) for u in users] + [str(a) for a in agents] + + # Prepare session parameters + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options or {}, # $10 + ] + + # Prepare lookup parameters + lookup_params = [ + developer_id, # $1 + session_id, # $2 + participant_types, # $3 + participant_ids, # $4 + ] + + return [ + (session_query, session_params), + (lookup_query, lookup_params), + ] diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py new file mode 100644 index 000000000..9f756f25c --- /dev/null +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -0,0 +1,138 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import CreateSessionRequest, Session +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +session_query = parse_one(""" +INSERT INTO sessions ( + developer_id, + session_id, + situation, + system_template, + metadata, + render_templates, + token_budget, + context_overflow, + forward_tool_calls, + recall_options +) +VALUES ( + $1, -- developer_id + $2, -- session_id + $3, -- situation + $4, -- system_template + $5, -- metadata + $6, -- render_templates + $7, -- token_budget + $8, -- context_overflow + $9, -- forward_tool_calls + $10 -- recall_options +) +RETURNING *; +""").sql(pretty=True) + +lookup_query = parse_one(""" +INSERT INTO session_lookup ( + developer_id, + session_id, + participant_type, + participant_id +) +SELECT + $1 as developer_id, + $2 as session_id, + unnest($3::participant_type[]) as participant_type, + unnest($4::uuid[]) as participant_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or participant does not exist.", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="A session with this ID already exists.", + ), + } +) +@wrap_in_class(Session, one=True, transform=lambda d: {**d, "id": d["session_id"]}) +@increase_counter("create_session") +@pg_query +@beartype +async def create_session( + *, + developer_id: UUID, + session_id: UUID, + data: CreateSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to create a new session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (CreateSessionRequest): Session creation data + + Returns: + list[tuple[str, list]]: SQL queries and their parameters + """ + # Handle participants + users = data.users or ([data.user] if data.user else []) + agents = data.agents or ([data.agent] if data.agent else []) + + if not agents: + raise HTTPException( + status_code=400, + detail="At least one agent must be provided", + ) + + if data.agent and data.agents: + raise HTTPException( + status_code=400, + detail="Only one of 'agent' or 'agents' should be provided", + ) + + # Prepare participant arrays for lookup query + participant_types = ( + ["user"] * len(users) + ["agent"] * len(agents) + ) + participant_ids = [str(u) for u in users] + [str(a) for a in agents] + + # Prepare session parameters + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options or {}, # $10 + ] + + # Prepare lookup parameters + lookup_params = [ + developer_id, # $1 + session_id, # $2 + participant_types, # $3 + participant_ids, # $4 + ] + + return [ + (session_query, session_params), + (lookup_query, lookup_params), + ] diff --git a/agents-api/agents_api/queries/sessions/delete_session.py b/agents-api/agents_api/queries/sessions/delete_session.py new file mode 100644 index 000000000..2e3234fe2 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/delete_session.py @@ -0,0 +1,69 @@ +"""This module contains the implementation for deleting sessions from the PostgreSQL database.""" + +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +lookup_query = parse_one(""" +DELETE FROM session_lookup +WHERE developer_id = $1 AND session_id = $2; +""").sql(pretty=True) + +session_query = parse_one(""" +DELETE FROM sessions +WHERE developer_id = $1 AND session_id = $2 +RETURNING session_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + } +) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_session") +@pg_query +@beartype +async def delete_session( + *, + developer_id: UUID, + session_id: UUID, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to delete a session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID to delete + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + params = [developer_id, session_id] + + return [ + (lookup_query, params), # Delete from lookup table first due to FK constraint + (session_query, params), # Then delete from sessions table + ] diff --git a/agents-api/agents_api/queries/sessions/get_session.py b/agents-api/agents_api/queries/sessions/get_session.py new file mode 100644 index 000000000..441a1c5c3 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/get_session.py @@ -0,0 +1,85 @@ +"""This module contains functions for retrieving session data from the PostgreSQL database.""" + +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Session +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +raw_query = """ +WITH session_participants AS ( + SELECT + sl.session_id, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'agent') as agents, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'user') as users + FROM session_lookup sl + WHERE sl.developer_id = $1 AND sl.session_id = $2 + GROUP BY sl.session_id +) +SELECT + s.session_id as id, + s.developer_id, + s.situation, + s.system_template, + s.metadata, + s.render_templates, + s.token_budget, + s.context_overflow, + s.forward_tool_calls, + s.recall_options, + s.created_at, + s.updated_at, + sp.agents, + sp.users +FROM sessions s +LEFT JOIN session_participants sp ON s.session_id = sp.session_id +WHERE s.developer_id = $1 AND s.session_id = $2; +""" + +# Parse and optimize the query +query = parse_one(raw_query).sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found" + ), + } +) +@wrap_in_class(Session, one=True) +@increase_counter("get_session") +@pg_query +@beartype +async def get_session( + *, + developer_id: UUID, + session_id: UUID, +) -> tuple[str, list]: + """ + Constructs SQL query to retrieve a session and its participants. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + query, + [developer_id, session_id], + ) diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py new file mode 100644 index 000000000..80986a867 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -0,0 +1,109 @@ +"""This module contains functions for querying session data from the PostgreSQL database.""" + +from typing import Any, Literal, TypeVar +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import Session +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +raw_query = """ +WITH session_participants AS ( + SELECT + sl.session_id, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'agent') as agents, + array_agg(sl.participant_id) FILTER (WHERE sl.participant_type = 'user') as users + FROM session_lookup sl + WHERE sl.developer_id = $1 + GROUP BY sl.session_id +) +SELECT + s.session_id as id, + s.developer_id, + s.situation, + s.system_template, + s.metadata, + s.render_templates, + s.token_budget, + s.context_overflow, + s.forward_tool_calls, + s.recall_options, + s.created_at, + s.updated_at, + sp.agents, + sp.users +FROM sessions s +LEFT JOIN session_participants sp ON s.session_id = sp.session_id +WHERE s.developer_id = $1 + AND ($5::jsonb IS NULL OR s.metadata @> $5::jsonb) +ORDER BY + CASE WHEN $3 = 'created_at' AND $4 = 'desc' THEN s.created_at END DESC, + CASE WHEN $3 = 'created_at' AND $4 = 'asc' THEN s.created_at END ASC, + CASE WHEN $3 = 'updated_at' AND $4 = 'desc' THEN s.updated_at END DESC, + CASE WHEN $3 = 'updated_at' AND $4 = 'asc' THEN s.updated_at END ASC +LIMIT $2 OFFSET $6; +""" + +# Parse and optimize the query +# query = parse_one(raw_query).sql(pretty=True) +query = raw_query + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No sessions found" + ), + } +) +@wrap_in_class(Session) +@increase_counter("list_sessions") +@pg_query +@beartype +async def list_sessions( + *, + developer_id: UUID, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", + metadata_filter: dict[str, Any] = {}, +) -> tuple[str, list]: + """ + Lists sessions from the PostgreSQL database based on the provided filters. + + Args: + developer_id (UUID): The developer's UUID + limit (int): Maximum number of sessions to return + offset (int): Number of sessions to skip + sort_by (str): Field to sort by ('created_at' or 'updated_at') + direction (str): Sort direction ('asc' or 'desc') + metadata_filter (dict): Dictionary of metadata fields to filter by + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + query, + [ + developer_id, # $1 + limit, # $2 + sort_by, # $3 + direction, # $4 + metadata_filter or None, # $5 + offset, # $6 + ], + ) diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py new file mode 100644 index 000000000..b14b94a8a --- /dev/null +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -0,0 +1,131 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import PatchSessionRequest, ResourceUpdatedResponse +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +# Build dynamic SET clause based on provided fields +session_query = parse_one(""" +WITH updated_session AS ( + UPDATE sessions + SET + situation = COALESCE($3, situation), + system_template = COALESCE($4, system_template), + metadata = sessions.metadata || $5, + render_templates = COALESCE($6, render_templates), + token_budget = COALESCE($7, token_budget), + context_overflow = COALESCE($8, context_overflow), + forward_tool_calls = COALESCE($9, forward_tool_calls), + recall_options = sessions.recall_options || $10 + WHERE + developer_id = $1 + AND session_id = $2 + RETURNING * +) +SELECT * FROM updated_session; +""").sql(pretty=True) + +lookup_query = parse_one(""" +WITH deleted_lookups AS ( + DELETE FROM session_lookup + WHERE developer_id = $1 AND session_id = $2 +) +INSERT INTO session_lookup ( + developer_id, + session_id, + participant_type, + participant_id +) +SELECT + $1 as developer_id, + $2 as session_id, + unnest($3::participant_type[]) as participant_type, + unnest($4::uuid[]) as participant_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or participant does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) +@wrap_in_class(ResourceUpdatedResponse, one=True) +@increase_counter("patch_session") +@pg_query +@beartype +async def patch_session( + *, + developer_id: UUID, + session_id: UUID, + data: PatchSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to patch a session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (PatchSessionRequest): Session patch data + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + # Handle participants + users = data.users or ([data.user] if data.user else []) + agents = data.agents or ([data.agent] if data.agent else []) + + if data.agent and data.agents: + raise HTTPException( + status_code=400, + detail="Only one of 'agent' or 'agents' should be provided", + ) + + # Prepare participant arrays for lookup query if participants are provided + participant_types = [] + participant_ids = [] + if users or agents: + participant_types = ["user"] * len(users) + ["agent"] * len(agents) + participant_ids = [str(u) for u in users] + [str(a) for a in agents] + + # Extract fields from data, using None for unset fields + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options or {}, # $10 + ] + + queries = [(session_query, session_params)] + + # Only add lookup query if participants are provided + if participant_types: + lookup_params = [ + developer_id, # $1 + session_id, # $2 + participant_types, # $3 + participant_ids, # $4 + ] + queries.append((lookup_query, lookup_params)) + + return queries diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py new file mode 100644 index 000000000..2999e21f6 --- /dev/null +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -0,0 +1,131 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateSessionRequest +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL queries +session_query = parse_one(""" +UPDATE sessions +SET + situation = $3, + system_template = $4, + metadata = $5, + render_templates = $6, + token_budget = $7, + context_overflow = $8, + forward_tool_calls = $9, + recall_options = $10 +WHERE + developer_id = $1 + AND session_id = $2 +RETURNING *; +""").sql(pretty=True) + +lookup_query = parse_one(""" +WITH deleted_lookups AS ( + DELETE FROM session_lookup + WHERE developer_id = $1 AND session_id = $2 +) +INSERT INTO session_lookup ( + developer_id, + session_id, + participant_type, + participant_id +) +SELECT + $1 as developer_id, + $2 as session_id, + unnest($3::participant_type[]) as participant_type, + unnest($4::uuid[]) as participant_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or participant does not exist.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) +@wrap_in_class(ResourceUpdatedResponse, one=True) +@increase_counter("update_session") +@pg_query +@beartype +async def update_session( + *, + developer_id: UUID, + session_id: UUID, + data: UpdateSessionRequest, +) -> list[tuple[str, list]]: + """ + Constructs SQL queries to update a session and its participant lookups. + + Args: + developer_id (UUID): The developer's UUID + session_id (UUID): The session's UUID + data (UpdateSessionRequest): Session update data + + Returns: + list[tuple[str, list]]: List of SQL queries and their parameters + """ + # Handle participants + users = data.users or ([data.user] if data.user else []) + agents = data.agents or ([data.agent] if data.agent else []) + + if not agents: + raise HTTPException( + status_code=400, + detail="At least one agent must be provided", + ) + + if data.agent and data.agents: + raise HTTPException( + status_code=400, + detail="Only one of 'agent' or 'agents' should be provided", + ) + + # Prepare participant arrays for lookup query + participant_types = ( + ["user"] * len(users) + ["agent"] * len(agents) + ) + participant_ids = [str(u) for u in users] + [str(a) for a in agents] + + # Prepare session parameters + session_params = [ + developer_id, # $1 + session_id, # $2 + data.situation, # $3 + data.system_template, # $4 + data.metadata or {}, # $5 + data.render_templates, # $6 + data.token_budget, # $7 + data.context_overflow, # $8 + data.forward_tool_calls, # $9 + data.recall_options or {}, # $10 + ] + + # Prepare lookup parameters + lookup_params = [ + developer_id, # $1 + session_id, # $2 + participant_types, # $3 + participant_ids, # $4 + ] + + return [ + (session_query, session_params), + (lookup_query, lookup_params), + ] diff --git a/agents-api/agents_api/queries/users/list_users.py b/agents-api/agents_api/queries/users/list_users.py index 7f3677eab..74b40eb7b 100644 --- a/agents-api/agents_api/queries/users/list_users.py +++ b/agents-api/agents_api/queries/users/list_users.py @@ -37,9 +37,6 @@ OFFSET $3; """ -# Parse and optimize the query -# query = parse_one(raw_query).sql(pretty=True) - @rewrap_exceptions( { diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index e8ec40367..262b5aef8 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -1,160 +1,191 @@ -# # Tests for session queries - -# from uuid_extensions import uuid7 -# from ward import test - -# from agents_api.autogen.openapi_model import ( -# CreateOrUpdateSessionRequest, -# CreateSessionRequest, -# Session, -# ) -# from agents_api.queries.session.count_sessions import count_sessions -# from agents_api.queries.session.create_or_update_session import create_or_update_session -# from agents_api.queries.session.create_session import create_session -# from agents_api.queries.session.delete_session import delete_session -# from agents_api.queries.session.get_session import get_session -# from agents_api.queries.session.list_sessions import list_sessions -# from tests.fixtures import ( -# cozo_client, -# test_agent, -# test_developer_id, -# test_session, -# test_user, -# ) - -# MODEL = "gpt-4o-mini" - - -# @test("query: create session") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# create_session( +""" +This module contains tests for SQL query generation functions in the sessions module. +Tests verify the SQL queries without actually executing them against a database. +""" + + +from uuid import UUID + +import asyncpg +from uuid_extensions import uuid7 +from ward import raises, test + +from agents_api.autogen.openapi_model import ( + CreateOrUpdateSessionRequest, + CreateSessionRequest, + PatchSessionRequest, + ResourceDeletedResponse, + ResourceUpdatedResponse, + Session, + UpdateSessionRequest, +) +from agents_api.clients.pg import create_db_pool +from agents_api.queries.sessions import ( + count_sessions, + create_or_update_session, + create_session, + delete_session, + get_session, + list_sessions, + patch_session, + update_session, +) +from tests.fixtures import pg_dsn, test_developer_id # , test_session, test_agent, test_user + + +# @test("query: create session sql") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): +# """Test that a session can be successfully created.""" + +# pool = await create_db_pool(dsn=dsn) +# await create_session( # developer_id=developer_id, +# session_id=uuid7(), # data=CreateSessionRequest( # users=[user.id], # agents=[agent.id], -# situation="test session about", +# situation="test session", # ), -# client=client, +# connection_pool=pool, # ) -# @test("query: create session no user") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# create_session( +# @test("query: create or update session sql") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): +# """Test that a session can be successfully created or updated.""" + +# pool = await create_db_pool(dsn=dsn) +# await create_or_update_session( # developer_id=developer_id, -# data=CreateSessionRequest( +# session_id=uuid7(), +# data=CreateOrUpdateSessionRequest( +# users=[user.id], # agents=[agent.id], -# situation="test session about", +# situation="test session", # ), -# client=client, +# connection_pool=pool, # ) -# @test("query: get session not exists") -# def _(client=cozo_client, developer_id=test_developer_id): -# session_id = uuid7() - -# try: -# get_session( -# session_id=session_id, -# developer_id=developer_id, -# client=client, -# ) -# except Exception: -# pass -# else: -# assert False, "Session should not exist" - +# @test("query: update session sql") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent): +# """Test that an existing session's information can be successfully updated.""" -# @test("query: get session exists") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# result = get_session( +# pool = await create_db_pool(dsn=dsn) +# update_result = await update_session( # session_id=session.id, # developer_id=developer_id, -# client=client, +# data=UpdateSessionRequest( +# agents=[agent.id], +# situation="updated session", +# ), +# connection_pool=pool, # ) -# assert result is not None -# assert isinstance(result, Session) +# assert update_result is not None +# assert isinstance(update_result, ResourceUpdatedResponse) +# assert update_result.updated_at > session.created_at -# @test("query: delete session") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# session = create_session( -# developer_id=developer_id, -# data=CreateSessionRequest( -# agent=agent.id, -# situation="test session about", -# ), -# client=client, -# ) +@test("query: get session not exists sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that retrieving a non-existent session returns an empty result.""" -# delete_session( -# session_id=session.id, -# developer_id=developer_id, -# client=client, -# ) + session_id = uuid7() + pool = await create_db_pool(dsn=dsn) -# try: -# get_session( -# session_id=session.id, -# developer_id=developer_id, -# client=client, -# ) -# except Exception: -# pass + with raises(Exception): + await get_session( + session_id=session_id, + developer_id=developer_id, + connection_pool=pool, + ) -# else: -# assert False, "Session should not exist" +# @test("query: get session exists sql") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): +# """Test that retrieving an existing session returns the correct session information.""" -# @test("query: list sessions") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# result = list_sessions( +# pool = await create_db_pool(dsn=dsn) +# result = await get_session( +# session_id=session.id, # developer_id=developer_id, -# client=client, +# connection_pool=pool, # ) -# assert isinstance(result, list) -# assert len(result) > 0 +# assert result is not None +# assert isinstance(result, Session) -# @test("query: count sessions") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# result = count_sessions( -# developer_id=developer_id, -# client=client, -# ) +@test("query: list sessions sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that listing sessions returns a collection of session information.""" + + pool = await create_db_pool(dsn=dsn) + result = await list_sessions( + developer_id=developer_id, + connection_pool=pool, + ) -# assert isinstance(result, dict) -# assert result["count"] > 0 + assert isinstance(result, list) + assert len(result) >= 1 + assert all(isinstance(session, Session) for session in result) -# @test("query: create or update session") -# def _( -# client=cozo_client, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# session_id = uuid7() +# @test("query: patch session sql") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent): +# """Test that a session can be successfully patched.""" -# create_or_update_session( -# session_id=session_id, +# pool = await create_db_pool(dsn=dsn) +# patch_result = await patch_session( # developer_id=developer_id, -# data=CreateOrUpdateSessionRequest( -# users=[user.id], +# session_id=session.id, +# data=PatchSessionRequest( # agents=[agent.id], -# situation="test session about", +# situation="patched session", +# metadata={"test": "metadata"}, # ), -# client=client, +# connection_pool=pool, # ) -# result = get_session( -# session_id=session_id, +# assert patch_result is not None +# assert isinstance(patch_result, ResourceUpdatedResponse) +# assert patch_result.updated_at > session.created_at + + +# @test("query: delete session sql") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): +# """Test that a session can be successfully deleted.""" + +# pool = await create_db_pool(dsn=dsn) +# delete_result = await delete_session( # developer_id=developer_id, -# client=client, +# session_id=session.id, +# connection_pool=pool, # ) -# assert result is not None -# assert isinstance(result, Session) -# assert result.id == session_id +# assert delete_result is not None +# assert isinstance(delete_result, ResourceDeletedResponse) + +# # Verify the session no longer exists +# with raises(Exception): +# await get_session( +# developer_id=developer_id, +# session_id=session.id, +# connection_pool=pool, +# ) + + +@test("query: count sessions sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that sessions can be counted.""" + + pool = await create_db_pool(dsn=dsn) + result = await count_sessions( + developer_id=developer_id, + connection_pool=pool, + ) + + assert isinstance(result, dict) + assert "count" in result + assert isinstance(result["count"], int) diff --git a/memory-store/migrations/000009_sessions.up.sql b/memory-store/migrations/000009_sessions.up.sql index 082f3823c..75b5fde9a 100644 --- a/memory-store/migrations/000009_sessions.up.sql +++ b/memory-store/migrations/000009_sessions.up.sql @@ -7,8 +7,7 @@ CREATE TABLE IF NOT EXISTS sessions ( situation TEXT, system_template TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - -- NOTE: Derived from entries - -- updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, metadata JSONB NOT NULL DEFAULT '{}'::JSONB, render_templates BOOLEAN NOT NULL DEFAULT TRUE, token_budget INTEGER, diff --git a/memory-store/migrations/000015_entries.up.sql b/memory-store/migrations/000015_entries.up.sql index 9985e4c41..e9d5c6a4f 100644 --- a/memory-store/migrations/000015_entries.up.sql +++ b/memory-store/migrations/000015_entries.up.sql @@ -85,4 +85,20 @@ OR UPDATE ON entries FOR EACH ROW EXECUTE FUNCTION optimized_update_token_count_after (); +-- Add trigger to update parent session's updated_at +CREATE OR REPLACE FUNCTION update_session_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + UPDATE sessions + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = NEW.session_id; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trg_update_session_updated_at +AFTER INSERT OR UPDATE ON entries +FOR EACH ROW +EXECUTE FUNCTION update_session_updated_at(); + COMMIT; \ No newline at end of file From 065c7d2ef68a762eb455a559f48e9108cc0d0d11 Mon Sep 17 00:00:00 2001 From: creatorrr Date: Tue, 17 Dec 2024 18:30:17 +0000 Subject: [PATCH 08/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/sessions/__init__.py | 1 - agents-api/agents_api/queries/sessions/count_sessions.py | 2 +- .../queries/sessions/create_or_update_session.py | 9 +++++---- agents-api/agents_api/queries/sessions/create_session.py | 4 +--- agents-api/agents_api/queries/sessions/get_session.py | 4 +--- agents-api/agents_api/queries/sessions/list_sessions.py | 4 +--- agents-api/agents_api/queries/sessions/update_session.py | 4 +--- agents-api/tests/test_session_queries.py | 9 +++++---- 8 files changed, 15 insertions(+), 22 deletions(-) diff --git a/agents-api/agents_api/queries/sessions/__init__.py b/agents-api/agents_api/queries/sessions/__init__.py index bf192210b..d0f64ea5e 100644 --- a/agents-api/agents_api/queries/sessions/__init__.py +++ b/agents-api/agents_api/queries/sessions/__init__.py @@ -28,4 +28,3 @@ "patch_session", "update_session", ] - diff --git a/agents-api/agents_api/queries/sessions/count_sessions.py b/agents-api/agents_api/queries/sessions/count_sessions.py index 71c1ec0dc..2abdf22e5 100644 --- a/agents-api/agents_api/queries/sessions/count_sessions.py +++ b/agents-api/agents_api/queries/sessions/count_sessions.py @@ -52,4 +52,4 @@ async def count_sessions( return ( query, [developer_id], - ) \ No newline at end of file + ) diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index 4bbbef091..bc54bf31b 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -5,7 +5,10 @@ from fastapi import HTTPException from sqlglot import parse_one -from ...autogen.openapi_model import CreateOrUpdateSessionRequest, ResourceUpdatedResponse +from ...autogen.openapi_model import ( + CreateOrUpdateSessionRequest, + ResourceUpdatedResponse, +) from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class @@ -118,9 +121,7 @@ async def create_or_update_session( ) # Prepare participant arrays for lookup query - participant_types = ( - ["user"] * len(users) + ["agent"] * len(agents) - ) + participant_types = ["user"] * len(users) + ["agent"] * len(agents) participant_ids = [str(u) for u in users] + [str(a) for a in agents] # Prepare session parameters diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 9f756f25c..3074f087b 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -105,9 +105,7 @@ async def create_session( ) # Prepare participant arrays for lookup query - participant_types = ( - ["user"] * len(users) + ["agent"] * len(agents) - ) + participant_types = ["user"] * len(users) + ["agent"] * len(agents) participant_ids = [str(u) for u in users] + [str(a) for a in agents] # Prepare session parameters diff --git a/agents-api/agents_api/queries/sessions/get_session.py b/agents-api/agents_api/queries/sessions/get_session.py index 441a1c5c3..1f704539e 100644 --- a/agents-api/agents_api/queries/sessions/get_session.py +++ b/agents-api/agents_api/queries/sessions/get_session.py @@ -54,9 +54,7 @@ detail="The specified developer does not exist.", ), asyncpg.NoDataFoundError: partialclass( - HTTPException, - status_code=404, - detail="Session not found" + HTTPException, status_code=404, detail="Session not found" ), } ) diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py index 80986a867..5ce31803b 100644 --- a/agents-api/agents_api/queries/sessions/list_sessions.py +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -63,9 +63,7 @@ detail="The specified developer does not exist.", ), asyncpg.NoDataFoundError: partialclass( - HTTPException, - status_code=404, - detail="No sessions found" + HTTPException, status_code=404, detail="No sessions found" ), } ) diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py index 2999e21f6..01e21e732 100644 --- a/agents-api/agents_api/queries/sessions/update_session.py +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -98,9 +98,7 @@ async def update_session( ) # Prepare participant arrays for lookup query - participant_types = ( - ["user"] * len(users) + ["agent"] * len(agents) - ) + participant_types = ["user"] * len(users) + ["agent"] * len(agents) participant_ids = [str(u) for u in users] + [str(a) for a in agents] # Prepare session parameters diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 262b5aef8..90b40a0d8 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -3,7 +3,6 @@ Tests verify the SQL queries without actually executing them against a database. """ - from uuid import UUID import asyncpg @@ -30,13 +29,15 @@ patch_session, update_session, ) -from tests.fixtures import pg_dsn, test_developer_id # , test_session, test_agent, test_user - +from tests.fixtures import ( + pg_dsn, + test_developer_id, +) # , test_session, test_agent, test_user # @test("query: create session sql") # async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): # """Test that a session can be successfully created.""" - + # pool = await create_db_pool(dsn=dsn) # await create_session( # developer_id=developer_id, From 2eb10d3110872e3c8a302a1b7a48c0f1e13580b6 Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Tue, 17 Dec 2024 23:33:09 -0500 Subject: [PATCH 09/29] chore: developers and user refactor + add test for entry queries + bug fixes --- agents-api/agents_api/autogen/Entries.py | 1 + .../agents_api/autogen/openapi_model.py | 1 + .../agents_api/queries/developers/__init__.py | 7 + .../queries/developers/create_developer.py | 34 +-- .../queries/developers/get_developer.py | 25 ++- .../queries/developers/patch_developer.py | 28 ++- .../queries/developers/update_developer.py | 25 ++- .../queries/{entry => entries}/__init__.py | 8 +- .../queries/entries/create_entry.py | 196 +++++++++++++++++ .../queries/entries/delete_entry.py | 96 +++++++++ .../agents_api/queries/entries/get_history.py | 72 +++++++ .../agents_api/queries/entries/list_entry.py | 80 +++++++ .../queries/entry/create_entries.py | 107 ---------- .../queries/entry/delete_entries.py | 48 ----- .../agents_api/queries/entry/get_history.py | 73 ------- .../agents_api/queries/entry/list_entries.py | 76 ------- .../queries/users/create_or_update_user.py | 43 ++-- .../agents_api/queries/users/create_user.py | 41 ++-- .../agents_api/queries/users/delete_user.py | 31 +-- .../agents_api/queries/users/get_user.py | 33 ++- .../agents_api/queries/users/list_users.py | 42 ++-- .../agents_api/queries/users/patch_user.py | 50 +++-- .../agents_api/queries/users/update_user.py | 27 +-- agents-api/tests/test_developer_queries.py | 1 - agents-api/tests/test_entry_queries.py | 200 ++++++++---------- agents-api/tests/test_user_queries.py | 1 - agents-api/tests/utils.py | 2 - .../integrations/autogen/Entries.py | 1 + typespec/entries/models.tsp | 1 + .../@typespec/openapi3/openapi-1.0.0.yaml | 4 + 30 files changed, 758 insertions(+), 596 deletions(-) rename agents-api/agents_api/queries/{entry => entries}/__init__.py (68%) create mode 100644 agents-api/agents_api/queries/entries/create_entry.py create mode 100644 agents-api/agents_api/queries/entries/delete_entry.py create mode 100644 agents-api/agents_api/queries/entries/get_history.py create mode 100644 agents-api/agents_api/queries/entries/list_entry.py delete mode 100644 agents-api/agents_api/queries/entry/create_entries.py delete mode 100644 agents-api/agents_api/queries/entry/delete_entries.py delete mode 100644 agents-api/agents_api/queries/entry/get_history.py delete mode 100644 agents-api/agents_api/queries/entry/list_entries.py diff --git a/agents-api/agents_api/autogen/Entries.py b/agents-api/agents_api/autogen/Entries.py index de37e77d8..d195b518f 100644 --- a/agents-api/agents_api/autogen/Entries.py +++ b/agents-api/agents_api/autogen/Entries.py @@ -52,6 +52,7 @@ class BaseEntry(BaseModel): ] tokenizer: str token_count: int + modelname: str = "gpt-40-mini" tool_calls: ( list[ ChosenFunctionCall diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index d19684cee..01042c58c 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -400,6 +400,7 @@ def from_model_input( source=source, tokenizer=tokenizer["type"], token_count=token_count, + modelname=model, **kwargs, ) diff --git a/agents-api/agents_api/queries/developers/__init__.py b/agents-api/agents_api/queries/developers/__init__.py index b3964aba4..c3d1d4bbb 100644 --- a/agents-api/agents_api/queries/developers/__init__.py +++ b/agents-api/agents_api/queries/developers/__init__.py @@ -20,3 +20,10 @@ from .get_developer import get_developer from .patch_developer import patch_developer from .update_developer import update_developer + +__all__ = [ + "create_developer", + "get_developer", + "patch_developer", + "update_developer", +] diff --git a/agents-api/agents_api/queries/developers/create_developer.py b/agents-api/agents_api/queries/developers/create_developer.py index 7ee845fbf..793d2f184 100644 --- a/agents-api/agents_api/queries/developers/create_developer.py +++ b/agents-api/agents_api/queries/developers/create_developer.py @@ -3,14 +3,19 @@ from beartype import beartype from sqlglot import parse_one from uuid_extensions import uuid7 +import asyncpg +from fastapi import HTTPException from ...common.protocol.developers import Developer from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) -query = parse_one(""" +# Define the raw SQL query +developer_query = parse_one(""" INSERT INTO developers ( developer_id, email, @@ -19,22 +24,25 @@ settings ) VALUES ( - $1, - $2, - $3, - $4, - $5::jsonb + $1, -- developer_id + $2, -- email + $3, -- active + $4, -- tags + $5::jsonb -- settings ) RETURNING *; """).sql(pretty=True) -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -49,6 +57,6 @@ async def create_developer( developer_id = str(developer_id or uuid7()) return ( - query, + developer_query, [developer_id, email, active, tags or [], settings or {}], ) diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 38302ab3b..54d4cf9d9 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -5,11 +5,12 @@ from beartype import beartype from fastapi import HTTPException -from pydantic import ValidationError from sqlglot import parse_one +import asyncpg from ...common.protocol.developers import Developer from ..utils import ( + partialclass, pg_query, rewrap_exceptions, wrap_in_class, @@ -18,18 +19,24 @@ # TODO: Add verify_developer verify_developer = None -query = parse_one("SELECT * FROM developers WHERE developer_id = $1").sql(pretty=True) +# Define the raw SQL query +developer_query = parse_one(""" +SELECT * FROM developers WHERE developer_id = $1 -- developer_id +""").sql(pretty=True) ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -40,6 +47,6 @@ async def get_developer( developer_id = str(developer_id) return ( - query, + developer_query, [developer_id], ) diff --git a/agents-api/agents_api/queries/developers/patch_developer.py b/agents-api/agents_api/queries/developers/patch_developer.py index 49edfe370..b37fc7c5e 100644 --- a/agents-api/agents_api/queries/developers/patch_developer.py +++ b/agents-api/agents_api/queries/developers/patch_developer.py @@ -2,27 +2,35 @@ from beartype import beartype from sqlglot import parse_one +import asyncpg +from fastapi import HTTPException from ...common.protocol.developers import Developer from ..utils import ( pg_query, wrap_in_class, + partialclass, + rewrap_exceptions, ) -query = parse_one(""" +# Define the raw SQL query +developer_query = parse_one(""" UPDATE developers -SET email = $1, active = $2, tags = tags || $3, settings = settings || $4 -WHERE developer_id = $5 +SET email = $1, active = $2, tags = tags || $3, settings = settings || $4 -- settings +WHERE developer_id = $5 -- developer_id RETURNING *; """).sql(pretty=True) -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -37,6 +45,6 @@ async def patch_developer( developer_id = str(developer_id) return ( - query, + developer_query, [email, active, tags or [], settings or {}, developer_id], ) diff --git a/agents-api/agents_api/queries/developers/update_developer.py b/agents-api/agents_api/queries/developers/update_developer.py index 8350d45a0..410d5ca12 100644 --- a/agents-api/agents_api/queries/developers/update_developer.py +++ b/agents-api/agents_api/queries/developers/update_developer.py @@ -2,14 +2,18 @@ from beartype import beartype from sqlglot import parse_one - +import asyncpg +from fastapi import HTTPException from ...common.protocol.developers import Developer from ..utils import ( pg_query, wrap_in_class, + partialclass, + rewrap_exceptions, ) -query = parse_one(""" +# Define the raw SQL query +developer_query = parse_one(""" UPDATE developers SET email = $1, active = $2, tags = $3, settings = $4 WHERE developer_id = $5 @@ -17,12 +21,15 @@ """).sql(pretty=True) -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=403), -# ValidationError: partialclass(HTTPException, status_code=500), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) @pg_query @beartype @@ -37,6 +44,6 @@ async def update_developer( developer_id = str(developer_id) return ( - query, + developer_query, [email, active, tags or [], settings or {}, developer_id], ) diff --git a/agents-api/agents_api/queries/entry/__init__.py b/agents-api/agents_api/queries/entries/__init__.py similarity index 68% rename from agents-api/agents_api/queries/entry/__init__.py rename to agents-api/agents_api/queries/entries/__init__.py index 2ad83f115..7c196dd62 100644 --- a/agents-api/agents_api/queries/entry/__init__.py +++ b/agents-api/agents_api/queries/entries/__init__.py @@ -8,14 +8,14 @@ - Listing entries with filtering and pagination """ -from .create_entries import create_entries -from .delete_entries import delete_entries_for_session +from .create_entry import create_entries +from .delete_entry import delete_entries from .get_history import get_history -from .list_entries import list_entries +from .list_entry import list_entries __all__ = [ "create_entries", - "delete_entries_for_session", + "delete_entries", "get_history", "list_entries", ] diff --git a/agents-api/agents_api/queries/entries/create_entry.py b/agents-api/agents_api/queries/entries/create_entry.py new file mode 100644 index 000000000..471d02fe6 --- /dev/null +++ b/agents-api/agents_api/queries/entries/create_entry.py @@ -0,0 +1,196 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation +from ...common.utils.datetime import utcnow +from ...common.utils.messages import content_to_json +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for creating entries with a developer check +entry_query = (""" +WITH data AS ( + SELECT + unnest($1::uuid[]) AS session_id, + unnest($2::uuid[]) AS entry_id, + unnest($3::text[]) AS source, + unnest($4::text[])::chat_role AS role, + unnest($5::text[]) AS event_type, + unnest($6::text[]) AS name, + array[unnest($7::jsonb[])] AS content, + unnest($8::text[]) AS tool_call_id, + array[unnest($9::jsonb[])] AS tool_calls, + unnest($10::text[]) AS model, + unnest($11::int[]) AS token_count, + unnest($12::timestamptz[]) AS created_at, + unnest($13::timestamptz[]) AS timestamp +) +INSERT INTO entries ( + session_id, + entry_id, + source, + role, + event_type, + name, + content, + tool_call_id, + tool_calls, + model, + token_count, + created_at, + timestamp +) +SELECT + d.session_id, + d.entry_id, + d.source, + d.role, + d.event_type, + d.name, + d.content, + d.tool_call_id, + d.tool_calls, + d.model, + d.token_count, + d.created_at, + d.timestamp +FROM + data d +JOIN + developers ON developers.developer_id = $14 +RETURNING *; +""") + +# Define the raw SQL query for creating entry relations +entry_relation_query = (""" +WITH data AS ( + SELECT + unnest($1::uuid[]) AS session_id, + unnest($2::uuid[]) AS head, + unnest($3::text[]) AS relation, + unnest($4::uuid[]) AS tail, + unnest($5::boolean[]) AS is_leaf +) +INSERT INTO entry_relations ( + session_id, + head, + relation, + tail, + is_leaf +) +SELECT + d.session_id, + d.head, + d.relation, + d.tail, + d.is_leaf +FROM + data d +JOIN + developers ON developers.developer_id = $6 +RETURNING *; +""") + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=409, + detail=str(exc), + ), + asyncpg.NotNullViolationError: lambda exc: HTTPException( + status_code=400, + detail=str(exc), + ), + } +) +@wrap_in_class( + Entry, + transform=lambda d: { + "id": UUID(d.pop("entry_id")), + **d, + }, +) +@increase_counter("create_entries") +@pg_query +@beartype +async def create_entries( + *, + developer_id: UUID, + session_id: UUID, + data: list[CreateEntryRequest], +) -> tuple[str, list]: + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + # Prepare the parameters for the query + params = [ + [session_id] * len(data_dicts), # $1 + [item.pop("id", None) or str(uuid7()) for item in data_dicts], # $2 + [item.get("source") for item in data_dicts], # $3 + [item.get("role") for item in data_dicts], # $4 + [item.get("event_type") or "message.create" for item in data_dicts], # $5 + [item.get("name") for item in data_dicts], # $6 + [content_to_json(item.get("content") or {}) for item in data_dicts], # $7 + [item.get("tool_call_id") for item in data_dicts], # $8 + [content_to_json(item.get("tool_calls") or {}) for item in data_dicts], # $9 + [item.get("modelname") for item in data_dicts], # $10 + [item.get("token_count") for item in data_dicts], # $11 + [item.get("created_at") or utcnow() for item in data_dicts], # $12 + [utcnow() for _ in data_dicts], # $13 + developer_id, # $14 + ] + + return ( + entry_query, + params, + ) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=409, + detail=str(exc), + ), + } +) +@wrap_in_class(Relation) +@increase_counter("add_entry_relations") +@pg_query +@beartype +async def add_entry_relations( + *, + developer_id: UUID, + data: list[Relation], +) -> tuple[str, list]: + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + # Prepare the parameters for the query + params = [ + [item.get("session_id") for item in data_dicts], # $1 + [item.get("head") for item in data_dicts], # $2 + [item.get("relation") for item in data_dicts], # $3 + [item.get("tail") for item in data_dicts], # $4 + [item.get("is_leaf", False) for item in data_dicts], # $5 + developer_id, # $6 + ] + + return ( + entry_relation_query, + params, + ) diff --git a/agents-api/agents_api/queries/entries/delete_entry.py b/agents-api/agents_api/queries/entries/delete_entry.py new file mode 100644 index 000000000..82615745f --- /dev/null +++ b/agents-api/agents_api/queries/entries/delete_entry.py @@ -0,0 +1,96 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...common.utils.datetime import utcnow +from ...autogen.openapi_model import ResourceDeletedResponse +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for deleting entries with a developer check +entry_query = parse_one(""" +DELETE FROM entries +USING developers +WHERE entries.session_id = $1 -- session_id +AND developers.developer_id = $2 +RETURNING entries.session_id as session_id; +""").sql(pretty=True) + +# Define the raw SQL query for deleting entries by entry_ids with a developer check +delete_entry_by_ids_query = parse_one(""" +DELETE FROM entries +USING developers +WHERE entries.entry_id = ANY($1) -- entry_ids +AND developers.developer_id = $2 +AND entries.session_id = $3 -- session_id +RETURNING entries.entry_id as entry_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=400, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + } +) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], # Only return session cleared + "deleted_at": utcnow(), + "jobs": [], + }, +) +@pg_query +@beartype +async def delete_entries_for_session( + *, + developer_id: UUID, + session_id: UUID, +) -> tuple[str, list]: + return ( + entry_query, + [session_id, developer_id], + ) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=400, + detail="The specified developer does not exist.", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="One or more specified entries do not exist.", + ), + } +) +@wrap_in_class( + ResourceDeletedResponse, + transform=lambda d: { + "id": d["entry_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@pg_query +@beartype +async def delete_entries( + *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] +) -> tuple[str, list]: + return ( + delete_entry_by_ids_query, + [entry_ids, developer_id, session_id], + ) diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py new file mode 100644 index 000000000..c6c38d366 --- /dev/null +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -0,0 +1,72 @@ +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import History +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for getting history with a developer check +history_query = parse_one(""" +SELECT + e.entry_id as id, -- entry_id + e.session_id, -- session_id + e.role, -- role + e.name, -- name + e.content, -- content + e.source, -- source + e.token_count, -- token_count + e.created_at, -- created_at + e.timestamp, -- timestamp + e.tool_calls, -- tool_calls + e.tool_call_id -- tool_call_id +FROM entries e +JOIN developers d ON d.developer_id = $3 +WHERE e.session_id = $1 +AND e.source = ANY($2) +ORDER BY e.created_at; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + } +) +@wrap_in_class( + History, + one=True, + transform=lambda d: { + **d, + "relations": [ + { + "head": r["head"], + "relation": r["relation"], + "tail": r["tail"], + } + for r in d.pop("relations") + ], + "entries": d.pop("entries"), + }, +) +@pg_query +@beartype +async def get_history( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], +) -> tuple[str, list]: + return ( + history_query, + [session_id, allowed_sources, developer_id], + ) diff --git a/agents-api/agents_api/queries/entries/list_entry.py b/agents-api/agents_api/queries/entries/list_entry.py new file mode 100644 index 000000000..5a4871a88 --- /dev/null +++ b/agents-api/agents_api/queries/entries/list_entry.py @@ -0,0 +1,80 @@ +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Entry +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +entry_query = """ +SELECT + e.entry_id as id, -- entry_id + e.session_id, -- session_id + e.role, -- role + e.name, -- name + e.content, -- content + e.source, -- source + e.token_count, -- token_count + e.created_at, -- created_at + e.timestamp -- timestamp +FROM entries e +JOIN developers d ON d.developer_id = $7 +LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id +WHERE e.session_id = $1 +AND e.source = ANY($2) +AND (er.relation IS NULL OR er.relation != ALL($8)) +ORDER BY e.$3 $4 +LIMIT $5 +OFFSET $6; +""" + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + } +) +@wrap_in_class(Entry) +@pg_query +@beartype +async def list_entries( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], + limit: int = 1, + offset: int = 0, + sort_by: Literal["created_at", "timestamp"] = "timestamp", + direction: Literal["asc", "desc"] = "asc", + exclude_relations: list[str] = [], +) -> tuple[str, list]: + + if limit < 1 or limit > 1000: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be non-negative") + + # making the parameters for the query + params = [ + session_id, # $1 + allowed_sources, # $2 + sort_by, # $3 + direction, # $4 + limit, # $5 + offset, # $6 + developer_id, # $7 + exclude_relations, # $8 + ] + return ( + entry_query, + params, + ) diff --git a/agents-api/agents_api/queries/entry/create_entries.py b/agents-api/agents_api/queries/entry/create_entries.py deleted file mode 100644 index d3b3b4982..000000000 --- a/agents-api/agents_api/queries/entry/create_entries.py +++ /dev/null @@ -1,107 +0,0 @@ -from uuid import UUID - -import asyncpg -from beartype import beartype -from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize -from uuid_extensions import uuid7 - -from ...autogen.openapi_model import CreateEntryRequest, Entry -from ...common.utils.datetime import utcnow -from ...common.utils.messages import content_to_json -from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class - -# Define the raw SQL query for creating entries with a developer check -raw_query = """ -INSERT INTO entries ( - session_id, - entry_id, - source, - role, - event_type, - name, - content, - tool_call_id, - tool_calls, - model, - token_count, - created_at, - timestamp -) -SELECT - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13 -FROM - developers -WHERE - developer_id = $14 -RETURNING *; -""" - -# Parse and optimize the query -query = optimize( - parse_one(raw_query), - schema={ - "entries": { - "session_id": "UUID", - "entry_id": "UUID", - "source": "TEXT", - "role": "chat_role", - "event_type": "TEXT", - "name": "TEXT", - "content": "JSONB[]", - "tool_call_id": "TEXT", - "tool_calls": "JSONB[]", - "model": "TEXT", - "token_count": "INTEGER", - "created_at": "TIMESTAMP", - "timestamp": "TIMESTAMP", - } - }, -).sql(pretty=True) - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), - asyncpg.UniqueViolationError: partialclass(HTTPException, status_code=409), - } -) -@wrap_in_class(Entry) -@increase_counter("create_entries") -@pg_query -@beartype -def create_entries( - *, - developer_id: UUID, - session_id: UUID, - data: list[CreateEntryRequest], - mark_session_as_updated: bool = True, -) -> tuple[str, list]: - data_dicts = [item.model_dump(mode="json") for item in data] - - params = [ - ( - session_id, - item.pop("id", None) or str(uuid7()), - item.get("source"), - item.get("role"), - item.get("event_type") or "message.create", - item.get("name"), - content_to_json(item.get("content") or []), - item.get("tool_call_id"), - item.get("tool_calls") or [], - item.get("model"), - item.get("token_count"), - (item.get("created_at") or utcnow()).timestamp(), - utcnow().timestamp(), - developer_id, - ) - for item in data_dicts - ] - - return ( - query, - params, - ) diff --git a/agents-api/agents_api/queries/entry/delete_entries.py b/agents-api/agents_api/queries/entry/delete_entries.py deleted file mode 100644 index 1fa34176f..000000000 --- a/agents-api/agents_api/queries/entry/delete_entries.py +++ /dev/null @@ -1,48 +0,0 @@ -from uuid import UUID - -import asyncpg -from beartype import beartype -from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class - -# Define the raw SQL query for deleting entries with a developer check -raw_query = """ -DELETE FROM entries -USING developers -WHERE entries.session_id = $1 -AND developers.developer_id = $2 -RETURNING entries.session_id as id; -""" - -# Parse and optimize the query -query = optimize( - parse_one(raw_query), - schema={ - "entries": { - "session_id": "UUID", - } - }, -).sql(pretty=True) - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(ResourceDeletedResponse, one=True) -@increase_counter("delete_entries_for_session") -@pg_query -@beartype -def delete_entries_for_session( - *, developer_id: UUID, session_id: UUID, mark_session_as_updated: bool = True -) -> tuple[str, list]: - return ( - query, - [session_id, developer_id], - ) diff --git a/agents-api/agents_api/queries/entry/get_history.py b/agents-api/agents_api/queries/entry/get_history.py deleted file mode 100644 index dd06734b0..000000000 --- a/agents-api/agents_api/queries/entry/get_history.py +++ /dev/null @@ -1,73 +0,0 @@ -from uuid import UUID - -import asyncpg -from beartype import beartype -from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize - -from ...autogen.openapi_model import History -from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class - -# Define the raw SQL query for getting history with a developer check -raw_query = """ -SELECT - e.entry_id as id, - e.session_id, - e.role, - e.name, - e.content, - e.source, - e.token_count, - e.created_at, - e.timestamp, - e.tool_calls, - e.tool_call_id -FROM entries e -JOIN developers d ON d.developer_id = $3 -WHERE e.session_id = $1 -AND e.source = ANY($2) -ORDER BY e.created_at; -""" - -# Parse and optimize the query -query = optimize( - parse_one(raw_query), - schema={ - "entries": { - "entry_id": "UUID", - "session_id": "UUID", - "role": "STRING", - "name": "STRING", - "content": "JSONB", - "source": "STRING", - "token_count": "INTEGER", - "created_at": "TIMESTAMP", - "timestamp": "TIMESTAMP", - "tool_calls": "JSONB", - "tool_call_id": "UUID", - } - }, -).sql(pretty=True) - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(History, one=True) -@increase_counter("get_history") -@pg_query -@beartype -def get_history( - *, - developer_id: UUID, - session_id: UUID, - allowed_sources: list[str] = ["api_request", "api_response"], -) -> tuple[str, list]: - return ( - query, - [session_id, allowed_sources, developer_id], - ) diff --git a/agents-api/agents_api/queries/entry/list_entries.py b/agents-api/agents_api/queries/entry/list_entries.py deleted file mode 100644 index 42add6899..000000000 --- a/agents-api/agents_api/queries/entry/list_entries.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Literal -from uuid import UUID - -import asyncpg -from beartype import beartype -from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize - -from ...autogen.openapi_model import Entry -from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class - -# Define the raw SQL query for listing entries with a developer check -raw_query = """ -SELECT - e.entry_id as id, - e.session_id, - e.role, - e.name, - e.content, - e.source, - e.token_count, - e.created_at, - e.timestamp -FROM entries e -JOIN developers d ON d.developer_id = $7 -WHERE e.session_id = $1 -AND e.source = ANY($2) -ORDER BY e.$3 $4 -LIMIT $5 OFFSET $6; -""" - -# Parse and optimize the query -query = optimize( - parse_one(raw_query), - schema={ - "entries": { - "entry_id": "UUID", - "session_id": "UUID", - "role": "STRING", - "name": "STRING", - "content": "JSONB", - "source": "STRING", - "token_count": "INTEGER", - "created_at": "TIMESTAMP", - "timestamp": "TIMESTAMP", - } - }, -).sql(pretty=True) - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: partialclass(HTTPException, status_code=400), - } -) -@wrap_in_class(Entry) -@increase_counter("list_entries") -@pg_query -@beartype -def list_entries( - *, - developer_id: UUID, - session_id: UUID, - allowed_sources: list[str] = ["api_request", "api_response"], - limit: int = -1, - offset: int = 0, - sort_by: Literal["created_at", "timestamp"] = "timestamp", - direction: Literal["asc", "desc"] = "asc", - exclude_relations: list[str] = [], -) -> tuple[str, list]: - return ( - query, - [session_id, allowed_sources, sort_by, direction, limit, offset, developer_id], - ) diff --git a/agents-api/agents_api/queries/users/create_or_update_user.py b/agents-api/agents_api/queries/users/create_or_update_user.py index d2be71bb4..6fd97942a 100644 --- a/agents-api/agents_api/queries/users/create_or_update_user.py +++ b/agents-api/agents_api/queries/users/create_or_update_user.py @@ -4,14 +4,13 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize -from ...autogen.openapi_model import CreateOrUpdateUserRequest, User from ...metrics.counters import increase_counter +from ...autogen.openapi_model import CreateOrUpdateUserRequest, User from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Optimize the raw query by using COALESCE for metadata to avoid explicit check -raw_query = """ +# Define the raw SQL query for creating or updating a user +user_query = parse_one(""" INSERT INTO users ( developer_id, user_id, @@ -20,21 +19,18 @@ metadata ) VALUES ( - $1, - $2, - $3, - $4, - $5 + $1, -- developer_id + $2, -- user_id + $3, -- name + $4, -- about + $5::jsonb -- metadata ) ON CONFLICT (developer_id, user_id) DO UPDATE SET name = EXCLUDED.name, about = EXCLUDED.about, metadata = EXCLUDED.metadata RETURNING *; -""" - -# Add index hint for better performance -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -51,7 +47,14 @@ ), } ) -@wrap_in_class(User, one=True, transform=lambda d: {**d, "id": d["user_id"]}) +@wrap_in_class( + User, + one=True, + transform=lambda d: { + **d, + "id": d["user_id"], + }, +) @increase_counter("create_or_update_user") @pg_query @beartype @@ -73,14 +76,14 @@ async def create_or_update_user( HTTPException: If developer doesn't exist (404) or on unique constraint violation (409) """ params = [ - developer_id, - user_id, - data.name, - data.about, - data.metadata or {}, + developer_id, # $1 + user_id, # $2 + data.name, # $3 + data.about, # $4 + data.metadata or {}, # $5 ] return ( - query, + user_query, params, ) diff --git a/agents-api/agents_api/queries/users/create_user.py b/agents-api/agents_api/queries/users/create_user.py index 66e8bcc27..d77fbff47 100644 --- a/agents-api/agents_api/queries/users/create_user.py +++ b/agents-api/agents_api/queries/users/create_user.py @@ -4,15 +4,14 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from uuid_extensions import uuid7 -from ...autogen.openapi_model import CreateUserRequest, User from ...metrics.counters import increase_counter +from ...autogen.openapi_model import CreateUserRequest, User from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" INSERT INTO users ( developer_id, user_id, @@ -21,17 +20,14 @@ metadata ) VALUES ( - $1, - $2, - $3, - $4, - $5 + $1, -- developer_id + $2, -- user_id + $3, -- name + $4, -- about + $5::jsonb -- metadata ) RETURNING *; -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -48,7 +44,14 @@ ), } ) -@wrap_in_class(User, one=True, transform=lambda d: {**d, "id": d["user_id"]}) +@wrap_in_class( + User, + one=True, + transform=lambda d: { + **d, + "id": d["user_id"], + }, +) @increase_counter("create_user") @pg_query @beartype @@ -72,14 +75,14 @@ async def create_user( user_id = user_id or uuid7() params = [ - developer_id, - user_id, - data.name, - data.about, - data.metadata or {}, + developer_id, # $1 + user_id, # $2 + data.name, # $3 + data.about, # $4 + data.metadata or {}, # $5 ] return ( - query, + user_query, params, ) diff --git a/agents-api/agents_api/queries/users/delete_user.py b/agents-api/agents_api/queries/users/delete_user.py index 520c8d695..86bcc0b26 100644 --- a/agents-api/agents_api/queries/users/delete_user.py +++ b/agents-api/agents_api/queries/users/delete_user.py @@ -4,18 +4,17 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +delete_query = parse_one(""" WITH deleted_data AS ( - DELETE FROM user_files - WHERE developer_id = $1 AND user_id = $2 + DELETE FROM user_files -- user_files + WHERE developer_id = $1 -- developer_id + AND user_id = $2 -- user_id ), deleted_docs AS ( DELETE FROM user_docs @@ -24,10 +23,7 @@ DELETE FROM users WHERE developer_id = $1 AND user_id = $2 RETURNING user_id, developer_id; -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -36,15 +32,24 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class( ResourceDeletedResponse, one=True, - transform=lambda d: {**d, "id": d["user_id"], "deleted_at": utcnow()}, + transform=lambda d: { + **d, + "id": d["user_id"], + "deleted_at": utcnow(), + "jobs": [], + }, ) -@increase_counter("delete_user") @pg_query @beartype async def delete_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: @@ -61,6 +66,6 @@ async def delete_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: """ return ( - query, + delete_query, [developer_id, user_id], ) diff --git a/agents-api/agents_api/queries/users/get_user.py b/agents-api/agents_api/queries/users/get_user.py index 6989c8edb..2b71f9192 100644 --- a/agents-api/agents_api/queries/users/get_user.py +++ b/agents-api/agents_api/queries/users/get_user.py @@ -4,29 +4,24 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import User -from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" SELECT - user_id as id, - developer_id, - name, - about, - metadata, - created_at, - updated_at + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at -- updated_at FROM users WHERE developer_id = $1 AND user_id = $2; -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -35,11 +30,15 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class(User, one=True) -@increase_counter("get_user") @pg_query @beartype async def get_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: @@ -56,6 +55,6 @@ async def get_user(*, developer_id: UUID, user_id: UUID) -> tuple[str, list]: """ return ( - query, + user_query, [developer_id, user_id], ) diff --git a/agents-api/agents_api/queries/users/list_users.py b/agents-api/agents_api/queries/users/list_users.py index 7f3677eab..0f0818135 100644 --- a/agents-api/agents_api/queries/users/list_users.py +++ b/agents-api/agents_api/queries/users/list_users.py @@ -4,24 +4,21 @@ import asyncpg from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import User -from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = """ WITH filtered_users AS ( SELECT - user_id as id, - developer_id, - name, - about, - metadata, - created_at, - updated_at + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at -- updated_at FROM users WHERE developer_id = $1 AND ($4::jsonb IS NULL OR metadata @> $4) @@ -37,9 +34,6 @@ OFFSET $3; """ -# Parse and optimize the query -# query = parse_one(raw_query).sql(pretty=True) - @rewrap_exceptions( { @@ -47,11 +41,15 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class(User) -@increase_counter("list_users") @pg_query @beartype async def list_users( @@ -84,15 +82,15 @@ async def list_users( raise HTTPException(status_code=400, detail="Offset must be non-negative") params = [ - developer_id, - limit, - offset, + developer_id, # $1 + limit, # $2 + offset, # $3 metadata_filter, # Will be NULL if not provided - sort_by, - direction, + sort_by, # $4 + direction, # $5 ] return ( - raw_query, + user_query, params, ) diff --git a/agents-api/agents_api/queries/users/patch_user.py b/agents-api/agents_api/queries/users/patch_user.py index 971e96b81..c55ee31b7 100644 --- a/agents-api/agents_api/queries/users/patch_user.py +++ b/agents-api/agents_api/queries/users/patch_user.py @@ -4,42 +4,38 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import PatchUserRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" UPDATE users SET name = CASE - WHEN $3::text IS NOT NULL THEN $3 + WHEN $3::text IS NOT NULL THEN $3 -- name ELSE name END, about = CASE - WHEN $4::text IS NOT NULL THEN $4 + WHEN $4::text IS NOT NULL THEN $4 -- about ELSE about END, metadata = CASE - WHEN $5::jsonb IS NOT NULL THEN metadata || $5 + WHEN $5::jsonb IS NOT NULL THEN metadata || $5 -- metadata ELSE metadata END WHERE developer_id = $1 AND user_id = $2 RETURNING - user_id as id, - developer_id, - name, - about, - metadata, - created_at, - updated_at; -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) + user_id as id, -- user_id + developer_id, -- developer_id + name, -- name + about, -- about + metadata, -- metadata + created_at, -- created_at + updated_at; -- updated_at +""").sql(pretty=True) @rewrap_exceptions( @@ -48,7 +44,12 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class(ResourceUpdatedResponse, one=True) @@ -71,11 +72,14 @@ async def patch_user( tuple[str, list]: SQL query and parameters """ params = [ - developer_id, - user_id, - data.name, # Will be NULL if not provided - data.about, # Will be NULL if not provided - data.metadata, # Will be NULL if not provided + developer_id, # $1 + user_id, # $2 + data.name, # $3. Will be NULL if not provided + data.about, # $4. Will be NULL if not provided + data.metadata, # $5. Will be NULL if not provided ] - return query, params + return ( + user_query, + params, + ) diff --git a/agents-api/agents_api/queries/users/update_user.py b/agents-api/agents_api/queries/users/update_user.py index 1fffdebe7..91572e15d 100644 --- a/agents-api/agents_api/queries/users/update_user.py +++ b/agents-api/agents_api/queries/users/update_user.py @@ -4,26 +4,22 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateUserRequest from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function -raw_query = """ +user_query = parse_one(""" UPDATE users SET - name = $3, - about = $4, - metadata = $5 -WHERE developer_id = $1 -AND user_id = $2 + name = $3, -- name + about = $4, -- about + metadata = $5 -- metadata +WHERE developer_id = $1 -- developer_id +AND user_id = $2 -- user_id RETURNING * -""" - -# Parse and optimize the query -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) @rewrap_exceptions( @@ -32,7 +28,12 @@ HTTPException, status_code=404, detail="The specified developer does not exist.", - ) + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified user does not exist.", + ), } ) @wrap_in_class( @@ -67,6 +68,6 @@ async def update_user( ] return ( - query, + user_query, params, ) diff --git a/agents-api/tests/test_developer_queries.py b/agents-api/tests/test_developer_queries.py index d360a7dc2..eedc07dd2 100644 --- a/agents-api/tests/test_developer_queries.py +++ b/agents-api/tests/test_developer_queries.py @@ -4,7 +4,6 @@ from ward import raises, test from agents_api.clients.pg import create_db_pool -from agents_api.common.protocol.developers import Developer from agents_api.queries.developers.create_developer import create_developer from agents_api.queries.developers.get_developer import ( get_developer, diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 220b8d232..242d0abfb 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -1,89 +1,53 @@ -# """ -# This module contains tests for entry queries against the CozoDB database. -# It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. -# """ +""" +This module contains tests for entry queries against the CozoDB database. +It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. +""" -# # Tests for entry queries +from uuid import UUID -# import time +from ward import test +from agents_api.clients.pg import create_db_pool -# from ward import test +from agents_api.queries.entries.create_entry import create_entries +from agents_api.queries.entries.list_entry import list_entries +from agents_api.queries.entries.get_history import get_history +from agents_api.queries.entries.delete_entry import delete_entries +from tests.fixtures import pg_dsn, test_developer_id # , test_session +from agents_api.autogen.openapi_model import CreateEntryRequest, Entry -# from agents_api.autogen.openapi_model import CreateEntryRequest -# from agents_api.queries.entry.create_entries import create_entries -# from agents_api.queries.entry.delete_entries import delete_entries -# from agents_api.queries.entry.get_history import get_history -# from agents_api.queries.entry.list_entries import list_entries -# from agents_api.queries.session.get_session import get_session -# from tests.fixtures import cozo_client, test_developer_id, test_session +# Test UUIDs for consistent testing +MODEL = "gpt-4o-mini" +SESSION_ID = UUID("123e4567-e89b-12d3-a456-426614174001") +TEST_DEVELOPER_ID = UUID("123e4567-e89b-12d3-a456-426614174000") +TEST_USER_ID = UUID("987e6543-e21b-12d3-a456-426614174000") -# MODEL = "gpt-4o-mini" +@test("query: create entry") +async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session + """Test the addition of a new entry to the database.""" -# @test("query: create entry") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the addition of a new entry to the database. -# Verifies that the entry can be successfully added using the create_entries function. -# """ + pool = await create_db_pool(dsn=dsn) + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + source="internal", + content="test entry content", + ) -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="internal", -# content="test entry content", -# ) - -# create_entries( -# developer_id=developer_id, -# session_id=session.id, -# data=[test_entry], -# mark_session_as_updated=False, -# client=client, -# ) + await create_entries( + developer_id=TEST_DEVELOPER_ID, + session_id=SESSION_ID, + data=[test_entry], + connection_pool=pool, + ) -# @test("query: create entry, update session") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the addition of a new entry to the database. -# Verifies that the entry can be successfully added using the create_entries function. -# """ - -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="internal", -# content="test entry content", -# ) - -# # TODO: We should make sessions.updated_at also a updated_at_ms field to avoid this sleep -# time.sleep(1) - -# create_entries( -# developer_id=developer_id, -# session_id=session.id, -# data=[test_entry], -# mark_session_as_updated=True, -# client=client, -# ) - -# updated_session = get_session( -# developer_id=developer_id, -# session_id=session.id, -# client=client, -# ) - -# assert updated_session.updated_at > session.updated_at - # @test("query: get entries") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the retrieval of entries from the database. -# Verifies that entries matching specific criteria can be successfully retrieved. -# """ +# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# """Test the retrieval of entries from the database.""" +# pool = await create_db_pool(dsn=dsn) # test_entry = CreateEntryRequest.from_model_input( # model=MODEL, # role="user", @@ -98,30 +62,32 @@ # source="internal", # ) -# create_entries( -# developer_id=developer_id, -# session_id=session.id, +# await create_entries( +# developer_id=TEST_DEVELOPER_ID, +# session_id=SESSION_ID, # data=[test_entry, internal_entry], -# client=client, +# connection_pool=pool, # ) -# result = list_entries( -# developer_id=developer_id, -# session_id=session.id, -# client=client, +# result = await list_entries( +# developer_id=TEST_DEVELOPER_ID, +# session_id=SESSION_ID, +# connection_pool=pool, # ) -# # Asserts that only one entry is retrieved, matching the session_id. + + +# # Assert that only one entry is retrieved, matching the session_id. # assert len(result) == 1 +# assert isinstance(result[0], Entry) +# assert result is not None # @test("query: get history") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the retrieval of entries from the database. -# Verifies that entries matching specific criteria can be successfully retrieved. -# """ +# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# """Test the retrieval of entry history from the database.""" +# pool = await create_db_pool(dsn=dsn) # test_entry = CreateEntryRequest.from_model_input( # model=MODEL, # role="user", @@ -136,31 +102,31 @@ # source="internal", # ) -# create_entries( +# await create_entries( # developer_id=developer_id, -# session_id=session.id, +# session_id=SESSION_ID, # data=[test_entry, internal_entry], -# client=client, +# connection_pool=pool, # ) -# result = get_history( +# result = await get_history( # developer_id=developer_id, -# session_id=session.id, -# client=client, +# session_id=SESSION_ID, +# connection_pool=pool, # ) -# # Asserts that only one entry is retrieved, matching the session_id. +# # Assert that entries are retrieved and have valid IDs. +# assert result is not None +# assert isinstance(result, History) # assert len(result.entries) > 0 # assert result.entries[0].id # @test("query: delete entries") -# def _(client=cozo_client, developer_id=test_developer_id, session=test_session): -# """ -# Tests the deletion of entries from the database. -# Verifies that entries can be successfully deleted using the delete_entries function. -# """ +# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# """Test the deletion of entries from the database.""" +# pool = await create_db_pool(dsn=dsn) # test_entry = CreateEntryRequest.from_model_input( # model=MODEL, # role="user", @@ -175,27 +141,29 @@ # source="internal", # ) -# created_entries = create_entries( +# created_entries = await create_entries( # developer_id=developer_id, -# session_id=session.id, +# session_id=SESSION_ID, # data=[test_entry, internal_entry], -# client=client, +# connection_pool=pool, # ) -# entry_ids = [entry.id for entry in created_entries] + # entry_ids = [entry.id for entry in created_entries] -# delete_entries( -# developer_id=developer_id, -# session_id=session.id, -# entry_ids=entry_ids, -# client=client, -# ) + # await delete_entries( + # developer_id=developer_id, + # session_id=SESSION_ID, + # entry_ids=[UUID("123e4567-e89b-12d3-a456-426614174002")], + # connection_pool=pool, + # ) -# result = list_entries( -# developer_id=developer_id, -# session_id=session.id, -# client=client, -# ) + # result = await list_entries( + # developer_id=developer_id, + # session_id=SESSION_ID, + # connection_pool=pool, + # ) -# # Asserts that no entries are retrieved after deletion. -# assert all(id not in [entry.id for entry in result] for id in entry_ids) + # Assert that no entries are retrieved after deletion. + # assert all(id not in [entry.id for entry in result] for id in entry_ids) + # assert len(result) == 0 + # assert result is not None diff --git a/agents-api/tests/test_user_queries.py b/agents-api/tests/test_user_queries.py index cbe7e0353..002532816 100644 --- a/agents-api/tests/test_user_queries.py +++ b/agents-api/tests/test_user_queries.py @@ -5,7 +5,6 @@ from uuid import UUID -import asyncpg from uuid_extensions import uuid7 from ward import raises, test diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py index 990a1015e..a4f98ac80 100644 --- a/agents-api/tests/utils.py +++ b/agents-api/tests/utils.py @@ -1,5 +1,4 @@ import asyncio -import json import logging import subprocess from contextlib import asynccontextmanager, contextmanager @@ -7,7 +6,6 @@ from typing import Any, Dict, Optional from unittest.mock import patch -import asyncpg from botocore import exceptions from fastapi.testclient import TestClient from litellm.types.utils import ModelResponse diff --git a/integrations-service/integrations/autogen/Entries.py b/integrations-service/integrations/autogen/Entries.py index de37e77d8..d195b518f 100644 --- a/integrations-service/integrations/autogen/Entries.py +++ b/integrations-service/integrations/autogen/Entries.py @@ -52,6 +52,7 @@ class BaseEntry(BaseModel): ] tokenizer: str token_count: int + modelname: str = "gpt-40-mini" tool_calls: ( list[ ChosenFunctionCall diff --git a/typespec/entries/models.tsp b/typespec/entries/models.tsp index 7f8c8b9fa..640e6831d 100644 --- a/typespec/entries/models.tsp +++ b/typespec/entries/models.tsp @@ -107,6 +107,7 @@ model BaseEntry { tokenizer: string; token_count: uint16; + modelname: string = "gpt-40-mini"; /** Tool calls generated by the model. */ tool_calls?: ChosenToolCall[] | null = null; diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index 0a12aac74..9b36baa2b 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -3064,6 +3064,7 @@ components: - source - tokenizer - token_count + - modelname - timestamp properties: role: @@ -3307,6 +3308,9 @@ components: token_count: type: integer format: uint16 + modelname: + type: string + default: gpt-40-mini tool_calls: type: array items: From b064234b2cdd37d33ee9acd547e13df673295eba Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Wed, 18 Dec 2024 04:34:14 +0000 Subject: [PATCH 10/29] refactor: Lint agents-api (CI) --- .../queries/developers/create_developer.py | 8 ++-- .../queries/developers/get_developer.py | 2 +- .../queries/developers/patch_developer.py | 8 ++-- .../queries/developers/update_developer.py | 9 ++-- .../queries/entries/create_entry.py | 8 ++-- .../queries/entries/delete_entry.py | 2 +- .../agents_api/queries/entries/list_entry.py | 3 +- .../queries/users/create_or_update_user.py | 2 +- .../agents_api/queries/users/create_user.py | 2 +- agents-api/tests/test_entry_queries.py | 48 +++++++++---------- 10 files changed, 45 insertions(+), 47 deletions(-) diff --git a/agents-api/agents_api/queries/developers/create_developer.py b/agents-api/agents_api/queries/developers/create_developer.py index 793d2f184..bed6371c4 100644 --- a/agents-api/agents_api/queries/developers/create_developer.py +++ b/agents-api/agents_api/queries/developers/create_developer.py @@ -1,17 +1,17 @@ from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from uuid_extensions import uuid7 -import asyncpg -from fastapi import HTTPException from ...common.protocol.developers import Developer from ..utils import ( + partialclass, pg_query, - wrap_in_class, rewrap_exceptions, - partialclass, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 54d4cf9d9..373a2fb36 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -3,10 +3,10 @@ from typing import Any, TypeVar from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -import asyncpg from ...common.protocol.developers import Developer from ..utils import ( diff --git a/agents-api/agents_api/queries/developers/patch_developer.py b/agents-api/agents_api/queries/developers/patch_developer.py index b37fc7c5e..af2ddb1f8 100644 --- a/agents-api/agents_api/queries/developers/patch_developer.py +++ b/agents-api/agents_api/queries/developers/patch_developer.py @@ -1,16 +1,16 @@ from uuid import UUID -from beartype import beartype -from sqlglot import parse_one import asyncpg +from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one from ...common.protocol.developers import Developer from ..utils import ( - pg_query, - wrap_in_class, partialclass, + pg_query, rewrap_exceptions, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/developers/update_developer.py b/agents-api/agents_api/queries/developers/update_developer.py index 410d5ca12..d41b333d5 100644 --- a/agents-api/agents_api/queries/developers/update_developer.py +++ b/agents-api/agents_api/queries/developers/update_developer.py @@ -1,15 +1,16 @@ from uuid import UUID -from beartype import beartype -from sqlglot import parse_one import asyncpg +from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one + from ...common.protocol.developers import Developer from ..utils import ( - pg_query, - wrap_in_class, partialclass, + pg_query, rewrap_exceptions, + wrap_in_class, ) # Define the raw SQL query diff --git a/agents-api/agents_api/queries/entries/create_entry.py b/agents-api/agents_api/queries/entries/create_entry.py index 471d02fe6..ea0e7e97d 100644 --- a/agents-api/agents_api/queries/entries/create_entry.py +++ b/agents-api/agents_api/queries/entries/create_entry.py @@ -13,7 +13,7 @@ from ..utils import pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for creating entries with a developer check -entry_query = (""" +entry_query = """ WITH data AS ( SELECT unnest($1::uuid[]) AS session_id, @@ -64,10 +64,10 @@ JOIN developers ON developers.developer_id = $14 RETURNING *; -""") +""" # Define the raw SQL query for creating entry relations -entry_relation_query = (""" +entry_relation_query = """ WITH data AS ( SELECT unnest($1::uuid[]) AS session_id, @@ -94,7 +94,7 @@ JOIN developers ON developers.developer_id = $6 RETURNING *; -""") +""" @rewrap_exceptions( diff --git a/agents-api/agents_api/queries/entries/delete_entry.py b/agents-api/agents_api/queries/entries/delete_entry.py index 82615745f..d6cdc6e87 100644 --- a/agents-api/agents_api/queries/entries/delete_entry.py +++ b/agents-api/agents_api/queries/entries/delete_entry.py @@ -5,8 +5,8 @@ from fastapi import HTTPException from sqlglot import parse_one -from ...common.utils.datetime import utcnow from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for deleting entries with a developer check diff --git a/agents-api/agents_api/queries/entries/list_entry.py b/agents-api/agents_api/queries/entries/list_entry.py index 5a4871a88..1fa6479d1 100644 --- a/agents-api/agents_api/queries/entries/list_entry.py +++ b/agents-api/agents_api/queries/entries/list_entry.py @@ -57,12 +57,11 @@ async def list_entries( direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], ) -> tuple[str, list]: - if limit < 1 or limit > 1000: raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") if offset < 0: raise HTTPException(status_code=400, detail="Offset must be non-negative") - + # making the parameters for the query params = [ session_id, # $1 diff --git a/agents-api/agents_api/queries/users/create_or_update_user.py b/agents-api/agents_api/queries/users/create_or_update_user.py index 6fd97942a..965ae4ce4 100644 --- a/agents-api/agents_api/queries/users/create_or_update_user.py +++ b/agents-api/agents_api/queries/users/create_or_update_user.py @@ -5,8 +5,8 @@ from fastapi import HTTPException from sqlglot import parse_one -from ...metrics.counters import increase_counter from ...autogen.openapi_model import CreateOrUpdateUserRequest, User +from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for creating or updating a user diff --git a/agents-api/agents_api/queries/users/create_user.py b/agents-api/agents_api/queries/users/create_user.py index d77fbff47..8f35a646c 100644 --- a/agents-api/agents_api/queries/users/create_user.py +++ b/agents-api/agents_api/queries/users/create_user.py @@ -6,8 +6,8 @@ from sqlglot import parse_one from uuid_extensions import uuid7 -from ...metrics.counters import increase_counter from ...autogen.openapi_model import CreateUserRequest, User +from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query outside the function diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 242d0abfb..c07891305 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -6,14 +6,14 @@ from uuid import UUID from ward import test -from agents_api.clients.pg import create_db_pool +from agents_api.autogen.openapi_model import CreateEntryRequest, Entry +from agents_api.clients.pg import create_db_pool from agents_api.queries.entries.create_entry import create_entries -from agents_api.queries.entries.list_entry import list_entries -from agents_api.queries.entries.get_history import get_history from agents_api.queries.entries.delete_entry import delete_entries +from agents_api.queries.entries.get_history import get_history +from agents_api.queries.entries.list_entry import list_entries from tests.fixtures import pg_dsn, test_developer_id # , test_session -from agents_api.autogen.openapi_model import CreateEntryRequest, Entry # Test UUIDs for consistent testing MODEL = "gpt-4o-mini" @@ -42,7 +42,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_sessi ) - # @test("query: get entries") # async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session # """Test the retrieval of entries from the database.""" @@ -76,7 +75,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_sessi # ) - # # Assert that only one entry is retrieved, matching the session_id. # assert len(result) == 1 # assert isinstance(result[0], Entry) @@ -148,22 +146,22 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_sessi # connection_pool=pool, # ) - # entry_ids = [entry.id for entry in created_entries] - - # await delete_entries( - # developer_id=developer_id, - # session_id=SESSION_ID, - # entry_ids=[UUID("123e4567-e89b-12d3-a456-426614174002")], - # connection_pool=pool, - # ) - - # result = await list_entries( - # developer_id=developer_id, - # session_id=SESSION_ID, - # connection_pool=pool, - # ) - - # Assert that no entries are retrieved after deletion. - # assert all(id not in [entry.id for entry in result] for id in entry_ids) - # assert len(result) == 0 - # assert result is not None +# entry_ids = [entry.id for entry in created_entries] + +# await delete_entries( +# developer_id=developer_id, +# session_id=SESSION_ID, +# entry_ids=[UUID("123e4567-e89b-12d3-a456-426614174002")], +# connection_pool=pool, +# ) + +# result = await list_entries( +# developer_id=developer_id, +# session_id=SESSION_ID, +# connection_pool=pool, +# ) + +# Assert that no entries are retrieved after deletion. +# assert all(id not in [entry.id for entry in result] for id in entry_ids) +# assert len(result) == 0 +# assert result is not None From a72812946d4bed45d68041962f4f6d1c7487c7d5 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Wed, 18 Dec 2024 13:21:02 +0530 Subject: [PATCH 11/29] feat(agents-api): Fix tests for sessions Signed-off-by: Diwank Singh Tomer --- agents-api/agents_api/app.py | 10 +-- .../queries/sessions/list_sessions.py | 3 +- .../queries/users/create_or_update_user.py | 1 - .../agents_api/queries/users/create_user.py | 1 - .../agents_api/queries/users/delete_user.py | 1 - .../agents_api/queries/users/get_user.py | 1 - .../agents_api/queries/users/list_users.py | 2 - .../agents_api/queries/users/patch_user.py | 1 - .../agents_api/queries/users/update_user.py | 1 - agents-api/agents_api/queries/utils.py | 54 +++++++------- agents-api/agents_api/web.py | 2 +- agents-api/tests/fixtures.py | 70 +++++++++++-------- agents-api/tests/test_session_queries.py | 3 +- agents-api/tests/test_user_queries.py | 1 - agents-api/tests/utils.py | 2 - memory-store/migrations/000015_entries.up.sql | 4 +- 16 files changed, 79 insertions(+), 78 deletions(-) diff --git a/agents-api/agents_api/app.py b/agents-api/agents_api/app.py index 735dfc8c0..ced41decb 100644 --- a/agents-api/agents_api/app.py +++ b/agents-api/agents_api/app.py @@ -1,7 +1,5 @@ -import json from contextlib import asynccontextmanager -import asyncpg from fastapi import FastAPI from prometheus_fastapi_instrumentator import Instrumentator @@ -11,9 +9,13 @@ @asynccontextmanager async def lifespan(app: FastAPI): - app.state.postgres_pool = await create_db_pool() + if not app.state.postgres_pool: + app.state.postgres_pool = await create_db_pool() + yield - await app.state.postgres_pool.close() + + if app.state.postgres_pool: + await app.state.postgres_pool.close() app: FastAPI = FastAPI( diff --git a/agents-api/agents_api/queries/sessions/list_sessions.py b/agents-api/agents_api/queries/sessions/list_sessions.py index 5ce31803b..3aabaf32d 100644 --- a/agents-api/agents_api/queries/sessions/list_sessions.py +++ b/agents-api/agents_api/queries/sessions/list_sessions.py @@ -1,12 +1,11 @@ """This module contains functions for querying session data from the PostgreSQL database.""" -from typing import Any, Literal, TypeVar +from typing import Any, Literal from uuid import UUID import asyncpg from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one from ...autogen.openapi_model import Session from ...metrics.counters import increase_counter diff --git a/agents-api/agents_api/queries/users/create_or_update_user.py b/agents-api/agents_api/queries/users/create_or_update_user.py index d2be71bb4..cff9ed09b 100644 --- a/agents-api/agents_api/queries/users/create_or_update_user.py +++ b/agents-api/agents_api/queries/users/create_or_update_user.py @@ -4,7 +4,6 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import CreateOrUpdateUserRequest, User from ...metrics.counters import increase_counter diff --git a/agents-api/agents_api/queries/users/create_user.py b/agents-api/agents_api/queries/users/create_user.py index 66e8bcc27..bdab2541f 100644 --- a/agents-api/agents_api/queries/users/create_user.py +++ b/agents-api/agents_api/queries/users/create_user.py @@ -4,7 +4,6 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateUserRequest, User diff --git a/agents-api/agents_api/queries/users/delete_user.py b/agents-api/agents_api/queries/users/delete_user.py index 520c8d695..6ea5e9664 100644 --- a/agents-api/agents_api/queries/users/delete_user.py +++ b/agents-api/agents_api/queries/users/delete_user.py @@ -4,7 +4,6 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow diff --git a/agents-api/agents_api/queries/users/get_user.py b/agents-api/agents_api/queries/users/get_user.py index 6989c8edb..ee75157e0 100644 --- a/agents-api/agents_api/queries/users/get_user.py +++ b/agents-api/agents_api/queries/users/get_user.py @@ -4,7 +4,6 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import User from ...metrics.counters import increase_counter diff --git a/agents-api/agents_api/queries/users/list_users.py b/agents-api/agents_api/queries/users/list_users.py index 74b40eb7b..4c30cd100 100644 --- a/agents-api/agents_api/queries/users/list_users.py +++ b/agents-api/agents_api/queries/users/list_users.py @@ -4,8 +4,6 @@ import asyncpg from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import User from ...metrics.counters import increase_counter diff --git a/agents-api/agents_api/queries/users/patch_user.py b/agents-api/agents_api/queries/users/patch_user.py index 971e96b81..3a2189014 100644 --- a/agents-api/agents_api/queries/users/patch_user.py +++ b/agents-api/agents_api/queries/users/patch_user.py @@ -4,7 +4,6 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import PatchUserRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter diff --git a/agents-api/agents_api/queries/users/update_user.py b/agents-api/agents_api/queries/users/update_user.py index 1fffdebe7..c3f436b5c 100644 --- a/agents-api/agents_api/queries/users/update_user.py +++ b/agents-api/agents_api/queries/users/update_user.py @@ -4,7 +4,6 @@ from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateUserRequest from ...metrics.counters import increase_counter diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index e93135172..e7be9f981 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -6,6 +6,7 @@ from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeVar, cast import asyncpg +from beartype import beartype import pandas as pd from asyncpg import Record from fastapi import HTTPException @@ -30,13 +31,16 @@ class NewCls(cls): return NewCls +@beartype def pg_query( func: Callable[P, tuple[str | list[str | None], dict]] | None = None, debug: bool | None = None, only_on_error: bool = False, timeit: bool = False, -): - def pg_query_dec(func: Callable[P, tuple[str | list[Any], dict]]): +) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]: + def pg_query_dec( + func: Callable[P, tuple[str, list[Any]] | list[tuple[str, list[Any]]]] + ) -> Callable[..., Callable[P, list[Record]]]: """ Decorator that wraps a function that takes arbitrary arguments, and returns a (query string, variables) tuple. @@ -47,19 +51,6 @@ def pg_query_dec(func: Callable[P, tuple[str | list[Any], dict]]): from pprint import pprint - # from tenacity import ( - # retry, - # retry_if_exception, - # stop_after_attempt, - # wait_exponential, - # ) - - # TODO: Remove all tenacity decorators - # @retry( - # stop=stop_after_attempt(4), - # wait=wait_exponential(multiplier=1, min=4, max=10), - # # retry=retry_if_exception(is_resource_busy), - # ) @wraps(func) async def wrapper( *args: P.args, @@ -76,17 +67,25 @@ async def wrapper( ) # Run the query + pool = ( + connection_pool + if connection_pool is not None + else cast(asyncpg.Pool, app.state.postgres_pool) + ) + + assert isinstance(variables, list) and len(variables) > 0 + + queries = query if isinstance(query, list) else [query] + variables_list = variables if isinstance(variables[0], list) else [variables] + zipped = zip(queries, variables_list) try: - pool = ( - connection_pool - if connection_pool is not None - else cast(asyncpg.Pool, app.state.postgres_pool) - ) async with pool.acquire() as conn: async with conn.transaction(): start = timeit and time.perf_counter() - results: list[Record] = await conn.fetch(query, *variables) + for query, variables in zipped: + results: list[Record] = await conn.fetch(query, *variables) + end = timeit and time.perf_counter() timeit and print( @@ -136,8 +135,7 @@ def wrap_in_class( cls: Type[ModelT] | Callable[..., ModelT], one: bool = False, transform: Callable[[dict], dict] | None = None, - _kind: str | None = None, -): +) -> Callable[..., Callable[..., ModelT | list[ModelT]]]: def _return_data(rec: list[Record]): data = [dict(r.items()) for r in rec] @@ -152,7 +150,9 @@ def _return_data(rec: list[Record]): objs: list[ModelT] = [cls(**item) for item in map(transform, data)] return objs - def decorator(func: Callable[P, pd.DataFrame | Awaitable[pd.DataFrame]]): + def decorator( + func: Callable[P, list[Record] | Awaitable[list[Record]]] + ) -> Callable[P, ModelT | list[ModelT]]: @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: return _return_data(func(*args, **kwargs)) @@ -179,7 +179,7 @@ def rewrap_exceptions( Type[BaseException] | Callable[[BaseException], BaseException], ], /, -): +) -> Callable[..., Callable[P, T | Awaitable[T]]]: def _check_error(error): nonlocal mapping @@ -199,7 +199,9 @@ def _check_error(error): raise new_error from error - def decorator(func: Callable[P, T | Awaitable[T]]): + def decorator( + func: Callable[P, T | Awaitable[T]] + ) -> Callable[..., Callable[P, T | Awaitable[T]]]: @wraps(func) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: diff --git a/agents-api/agents_api/web.py b/agents-api/agents_api/web.py index b354f97bf..379526e0f 100644 --- a/agents-api/agents_api/web.py +++ b/agents-api/agents_api/web.py @@ -9,7 +9,7 @@ import sentry_sdk import uvicorn import uvloop -from fastapi import APIRouter, Depends, FastAPI, Request, status +from fastapi import APIRouter, FastAPI, Request, status from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 389dafab2..1b86224a6 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -43,8 +43,8 @@ # from agents_api.queries.tools.create_tools import create_tools # from agents_api.queries.tools.delete_tool import delete_tool from agents_api.queries.users.create_user import create_user +from agents_api.queries.users.delete_user import delete_user -# from agents_api.queries.users.delete_user import delete_user from agents_api.web import app from .utils import ( @@ -67,11 +67,10 @@ def pg_dsn(): @fixture(scope="global") def test_developer_id(): if not multi_tenant_mode: - yield UUID(int=0) - return + return UUID(int=0) developer_id = uuid7() - yield developer_id + return developer_id # @fixture(scope="global") @@ -98,8 +97,7 @@ async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): connection_pool=pool, ) - yield developer - await pool.close() + return developer @fixture(scope="test") @@ -138,8 +136,7 @@ async def test_user(dsn=pg_dsn, developer=test_developer): connection_pool=pool, ) - yield user - await pool.close() + return user @fixture(scope="test") @@ -345,38 +342,49 @@ async def test_new_developer(dsn=pg_dsn, email=random_email): # "type": "function", # } -# async with get_pg_client(dsn=dsn) as client: -# [tool, *_] = await create_tools( +# [tool, *_] = await create_tools( +# developer_id=developer_id, +# agent_id=agent.id, +# data=[CreateToolRequest(**tool)], +# connection_pool=pool, +# ) +# yield tool + +# # Cleanup +# try: +# await delete_tool( # developer_id=developer_id, -# agent_id=agent.id, -# data=[CreateToolRequest(**tool)], -# client=client, +# tool_id=tool.id, +# connection_pool=pool, # ) -# yield tool +# finally: +# await pool.close() -# @fixture(scope="global") -# def client(dsn=pg_dsn): -# client = TestClient(app=app) -# client.state.pg_client = get_pg_client(dsn=dsn) -# return client +@fixture(scope="global") +async def client(dsn=pg_dsn): + pool = await create_db_pool(dsn=dsn) + client = TestClient(app=app) + client.state.postgres_pool = pool + return client -# @fixture(scope="global") -# def make_request(client=client, developer_id=test_developer_id): -# def _make_request(method, url, **kwargs): -# headers = kwargs.pop("headers", {}) -# headers = { -# **headers, -# api_key_header_name: api_key, -# } -# if multi_tenant_mode: -# headers["X-Developer-Id"] = str(developer_id) +@fixture(scope="global") +async def make_request(client=client, developer_id=test_developer_id): + def _make_request(method, url, **kwargs): + headers = kwargs.pop("headers", {}) + headers = { + **headers, + api_key_header_name: api_key, + } + + if multi_tenant_mode: + headers["X-Developer-Id"] = str(developer_id) -# return client.request(method, url, headers=headers, **kwargs) + return client.request(method, url, headers=headers, **kwargs) -# return _make_request + return _make_request @fixture(scope="global") diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 90b40a0d8..d182586dc 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -32,6 +32,7 @@ from tests.fixtures import ( pg_dsn, test_developer_id, + test_user, ) # , test_session, test_agent, test_user # @test("query: create session sql") @@ -118,7 +119,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): # assert isinstance(result, Session) -@test("query: list sessions sql") +@test("query: list sessions when none exist sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that listing sessions returns a collection of session information.""" diff --git a/agents-api/tests/test_user_queries.py b/agents-api/tests/test_user_queries.py index cbe7e0353..002532816 100644 --- a/agents-api/tests/test_user_queries.py +++ b/agents-api/tests/test_user_queries.py @@ -5,7 +5,6 @@ from uuid import UUID -import asyncpg from uuid_extensions import uuid7 from ward import raises, test diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py index 990a1015e..a4f98ac80 100644 --- a/agents-api/tests/utils.py +++ b/agents-api/tests/utils.py @@ -1,5 +1,4 @@ import asyncio -import json import logging import subprocess from contextlib import asynccontextmanager, contextmanager @@ -7,7 +6,6 @@ from typing import Any, Dict, Optional from unittest.mock import patch -import asyncpg from botocore import exceptions from fastapi.testclient import TestClient from litellm.types.utils import ModelResponse diff --git a/memory-store/migrations/000015_entries.up.sql b/memory-store/migrations/000015_entries.up.sql index e9d5c6a4f..c104091a2 100644 --- a/memory-store/migrations/000015_entries.up.sql +++ b/memory-store/migrations/000015_entries.up.sql @@ -1,7 +1,7 @@ BEGIN; -- Create chat_role enum -CREATE TYPE chat_role AS ENUM('user', 'assistant', 'tool', 'system'); +CREATE TYPE chat_role AS ENUM('user', 'assistant', 'tool', 'system', 'developer'); -- Create entries table CREATE TABLE IF NOT EXISTS entries ( @@ -101,4 +101,4 @@ AFTER INSERT OR UPDATE ON entries FOR EACH ROW EXECUTE FUNCTION update_session_updated_at(); -COMMIT; \ No newline at end of file +COMMIT; From 372f3203f390839716428d678ad78be60142f4d9 Mon Sep 17 00:00:00 2001 From: creatorrr Date: Wed, 18 Dec 2024 07:52:14 +0000 Subject: [PATCH 12/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/utils.py | 12 +++++++----- agents-api/tests/fixtures.py | 1 - 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index e7be9f981..3b5dc0bb0 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -6,9 +6,9 @@ from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeVar, cast import asyncpg -from beartype import beartype import pandas as pd from asyncpg import Record +from beartype import beartype from fastapi import HTTPException from pydantic import BaseModel @@ -39,7 +39,7 @@ def pg_query( timeit: bool = False, ) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]: def pg_query_dec( - func: Callable[P, tuple[str, list[Any]] | list[tuple[str, list[Any]]]] + func: Callable[P, tuple[str, list[Any]] | list[tuple[str, list[Any]]]], ) -> Callable[..., Callable[P, list[Record]]]: """ Decorator that wraps a function that takes arbitrary arguments, and @@ -76,7 +76,9 @@ async def wrapper( assert isinstance(variables, list) and len(variables) > 0 queries = query if isinstance(query, list) else [query] - variables_list = variables if isinstance(variables[0], list) else [variables] + variables_list = ( + variables if isinstance(variables[0], list) else [variables] + ) zipped = zip(queries, variables_list) try: @@ -151,7 +153,7 @@ def _return_data(rec: list[Record]): return objs def decorator( - func: Callable[P, list[Record] | Awaitable[list[Record]]] + func: Callable[P, list[Record] | Awaitable[list[Record]]], ) -> Callable[P, ModelT | list[ModelT]]: @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> ModelT | list[ModelT]: @@ -200,7 +202,7 @@ def _check_error(error): raise new_error from error def decorator( - func: Callable[P, T | Awaitable[T]] + func: Callable[P, T | Awaitable[T]], ) -> Callable[..., Callable[P, T | Awaitable[T]]]: @wraps(func) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 1b86224a6..c2aa350a8 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -44,7 +44,6 @@ # from agents_api.queries.tools.delete_tool import delete_tool from agents_api.queries.users.create_user import create_user from agents_api.queries.users.delete_user import delete_user - from agents_api.web import app from .utils import ( From 2b907eff42c33f8fc5fcc3acc30350e5c3af99cd Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Wed, 18 Dec 2024 19:56:31 +0530 Subject: [PATCH 13/29] wip(agents-api): Entry queries Signed-off-by: Diwank Singh Tomer --- agents-api/agents_api/env.py | 2 + .../agents_api/queries/agents/create_agent.py | 1 - .../queries/agents/create_or_update_agent.py | 1 - .../agents_api/queries/agents/delete_agent.py | 2 - .../agents_api/queries/agents/get_agent.py | 1 - .../agents_api/queries/agents/list_agents.py | 1 - .../agents_api/queries/agents/patch_agent.py | 1 - .../agents_api/queries/agents/update_agent.py | 1 - .../agents_api/queries/entries/__init__.py | 6 +- .../queries/entries/create_entries.py | 181 ++++++++++++++++ .../queries/entries/create_entry.py | 196 ------------------ .../queries/entries/delete_entries.py | 128 ++++++++++++ .../queries/entries/delete_entry.py | 96 --------- .../agents_api/queries/entries/get_history.py | 2 +- .../queries/entries/list_entries.py | 112 ++++++++++ .../agents_api/queries/entries/list_entry.py | 79 ------- agents-api/agents_api/queries/utils.py | 108 +++++++--- agents-api/agents_api/web.py | 1 - agents-api/tests/fixtures.py | 13 -- agents-api/tests/test_entry_queries.py | 53 +++-- agents-api/tests/test_session_queries.py | 15 -- 21 files changed, 538 insertions(+), 462 deletions(-) create mode 100644 agents-api/agents_api/queries/entries/create_entries.py delete mode 100644 agents-api/agents_api/queries/entries/create_entry.py create mode 100644 agents-api/agents_api/queries/entries/delete_entries.py delete mode 100644 agents-api/agents_api/queries/entries/delete_entry.py create mode 100644 agents-api/agents_api/queries/entries/list_entries.py delete mode 100644 agents-api/agents_api/queries/entries/list_entry.py diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 48623b771..8b9fd4dae 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -66,6 +66,8 @@ default="postgres://postgres:postgres@0.0.0.0:5432/postgres?sslmode=disable", ) +query_timeout: float = env.float("QUERY_TIMEOUT", default=90.0) + # Auth # ---- diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 46dc453f9..4c731d3dd 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -13,7 +13,6 @@ from uuid_extensions import uuid7 from ...autogen.openapi_model import Agent, CreateAgentRequest -from ...metrics.counters import increase_counter from ..utils import ( # generate_canonical_name, partialclass, diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index 261508237..96681255c 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -11,7 +11,6 @@ from psycopg import errors as psycopg_errors from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest -from ...metrics.counters import increase_counter from ..utils import ( # generate_canonical_name, partialclass, diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 9d6869a94..f3c64fd18 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -11,8 +11,6 @@ from psycopg import errors as psycopg_errors from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index 9061db7cf..5e0edbb98 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -11,7 +11,6 @@ from psycopg import errors as psycopg_errors from ...autogen.openapi_model import Agent -from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 62aed6536..5fda7c626 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -11,7 +11,6 @@ from psycopg import errors as psycopg_errors from ...autogen.openapi_model import Agent -from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index c418f5c26..450cbf8cc 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -11,7 +11,6 @@ from psycopg import errors as psycopg_errors from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse -from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index 4e38adfac..61548de70 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -11,7 +11,6 @@ from psycopg import errors as psycopg_errors from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest -from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, diff --git a/agents-api/agents_api/queries/entries/__init__.py b/agents-api/agents_api/queries/entries/__init__.py index 7c196dd62..e6db0efed 100644 --- a/agents-api/agents_api/queries/entries/__init__.py +++ b/agents-api/agents_api/queries/entries/__init__.py @@ -8,10 +8,10 @@ - Listing entries with filtering and pagination """ -from .create_entry import create_entries -from .delete_entry import delete_entries +from .create_entries import create_entries +from .delete_entries import delete_entries from .get_history import get_history -from .list_entry import list_entries +from .list_entries import list_entries __all__ = [ "create_entries", diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py new file mode 100644 index 000000000..ffbd2de22 --- /dev/null +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -0,0 +1,181 @@ +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation +from ...common.utils.datetime import utcnow +from ...common.utils.messages import content_to_json +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Query for checking if the session exists +session_exists_query = """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 + ) + THEN TRUE + ELSE (SELECT NULL::boolean WHERE FALSE) -- This raises a NO_DATA_FOUND error +END; +""" + +# Define the raw SQL query for creating entries +entry_query = """ +INSERT INTO entries ( + session_id, + entry_id, + source, + role, + event_type, + name, + content, + tool_call_id, + tool_calls, + model, + token_count, + created_at, + timestamp +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) +RETURNING *; +""" + +# Define the raw SQL query for creating entry relations +entry_relation_query = """ +INSERT INTO entry_relations ( + session_id, + head, + relation, + tail, + is_leaf +) VALUES ($1, $2, $3, $4, $5) +RETURNING *; +""" + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=409, + detail=str(exc), + ), + asyncpg.NotNullViolationError: lambda exc: HTTPException( + status_code=400, + detail=str(exc), + ), + } +) +@wrap_in_class( + Entry, + transform=lambda d: { + "id": UUID(d.pop("entry_id")), + **d, + }, +) +@increase_counter("create_entries") +@pg_query +@beartype +async def create_entries( + *, + developer_id: UUID, + session_id: UUID, + data: list[CreateEntryRequest], +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + # Prepare the parameters for the query + params = [] + + for item in data_dicts: + params.append( + [ + session_id, # $1 + item.pop("id", None) or str(uuid7()), # $2 + item.get("source"), # $3 + item.get("role"), # $4 + item.get("event_type") or "message.create", # $5 + item.get("name"), # $6 + content_to_json(item.get("content") or {}), # $7 + item.get("tool_call_id"), # $8 + content_to_json(item.get("tool_calls") or {}), # $9 + item.get("modelname"), # $10 + item.get("token_count"), # $11 + item.get("created_at") or utcnow(), # $12 + utcnow(), # $13 + ] + ) + + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetch", + ), + ( + entry_query, + params, + "fetchmany", + ), + ] + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=409, + detail=str(exc), + ), + } +) +@wrap_in_class(Relation) +@increase_counter("add_entry_relations") +@pg_query +@beartype +async def add_entry_relations( + *, + developer_id: UUID, + session_id: UUID, + data: list[Relation], +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + # Convert the data to a list of dictionaries + data_dicts = [item.model_dump(mode="json") for item in data] + + # Prepare the parameters for the query + params = [] + + for item in data_dicts: + params.append( + [ + item.get("session_id"), # $1 + item.get("head"), # $2 + item.get("relation"), # $3 + item.get("tail"), # $4 + item.get("is_leaf", False), # $5 + ] + ) + + return [ + ( + session_exists_query, + [session_id, developer_id], + "fetch", + ), + ( + entry_relation_query, + params, + "fetchmany", + ), + ] diff --git a/agents-api/agents_api/queries/entries/create_entry.py b/agents-api/agents_api/queries/entries/create_entry.py deleted file mode 100644 index ea0e7e97d..000000000 --- a/agents-api/agents_api/queries/entries/create_entry.py +++ /dev/null @@ -1,196 +0,0 @@ -from uuid import UUID - -import asyncpg -from beartype import beartype -from fastapi import HTTPException -from sqlglot import parse_one -from uuid_extensions import uuid7 - -from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation -from ...common.utils.datetime import utcnow -from ...common.utils.messages import content_to_json -from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class - -# Define the raw SQL query for creating entries with a developer check -entry_query = """ -WITH data AS ( - SELECT - unnest($1::uuid[]) AS session_id, - unnest($2::uuid[]) AS entry_id, - unnest($3::text[]) AS source, - unnest($4::text[])::chat_role AS role, - unnest($5::text[]) AS event_type, - unnest($6::text[]) AS name, - array[unnest($7::jsonb[])] AS content, - unnest($8::text[]) AS tool_call_id, - array[unnest($9::jsonb[])] AS tool_calls, - unnest($10::text[]) AS model, - unnest($11::int[]) AS token_count, - unnest($12::timestamptz[]) AS created_at, - unnest($13::timestamptz[]) AS timestamp -) -INSERT INTO entries ( - session_id, - entry_id, - source, - role, - event_type, - name, - content, - tool_call_id, - tool_calls, - model, - token_count, - created_at, - timestamp -) -SELECT - d.session_id, - d.entry_id, - d.source, - d.role, - d.event_type, - d.name, - d.content, - d.tool_call_id, - d.tool_calls, - d.model, - d.token_count, - d.created_at, - d.timestamp -FROM - data d -JOIN - developers ON developers.developer_id = $14 -RETURNING *; -""" - -# Define the raw SQL query for creating entry relations -entry_relation_query = """ -WITH data AS ( - SELECT - unnest($1::uuid[]) AS session_id, - unnest($2::uuid[]) AS head, - unnest($3::text[]) AS relation, - unnest($4::uuid[]) AS tail, - unnest($5::boolean[]) AS is_leaf -) -INSERT INTO entry_relations ( - session_id, - head, - relation, - tail, - is_leaf -) -SELECT - d.session_id, - d.head, - d.relation, - d.tail, - d.is_leaf -FROM - data d -JOIN - developers ON developers.developer_id = $6 -RETURNING *; -""" - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=409, - detail=str(exc), - ), - asyncpg.NotNullViolationError: lambda exc: HTTPException( - status_code=400, - detail=str(exc), - ), - } -) -@wrap_in_class( - Entry, - transform=lambda d: { - "id": UUID(d.pop("entry_id")), - **d, - }, -) -@increase_counter("create_entries") -@pg_query -@beartype -async def create_entries( - *, - developer_id: UUID, - session_id: UUID, - data: list[CreateEntryRequest], -) -> tuple[str, list]: - # Convert the data to a list of dictionaries - data_dicts = [item.model_dump(mode="json") for item in data] - - # Prepare the parameters for the query - params = [ - [session_id] * len(data_dicts), # $1 - [item.pop("id", None) or str(uuid7()) for item in data_dicts], # $2 - [item.get("source") for item in data_dicts], # $3 - [item.get("role") for item in data_dicts], # $4 - [item.get("event_type") or "message.create" for item in data_dicts], # $5 - [item.get("name") for item in data_dicts], # $6 - [content_to_json(item.get("content") or {}) for item in data_dicts], # $7 - [item.get("tool_call_id") for item in data_dicts], # $8 - [content_to_json(item.get("tool_calls") or {}) for item in data_dicts], # $9 - [item.get("modelname") for item in data_dicts], # $10 - [item.get("token_count") for item in data_dicts], # $11 - [item.get("created_at") or utcnow() for item in data_dicts], # $12 - [utcnow() for _ in data_dicts], # $13 - developer_id, # $14 - ] - - return ( - entry_query, - params, - ) - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=409, - detail=str(exc), - ), - } -) -@wrap_in_class(Relation) -@increase_counter("add_entry_relations") -@pg_query -@beartype -async def add_entry_relations( - *, - developer_id: UUID, - data: list[Relation], -) -> tuple[str, list]: - # Convert the data to a list of dictionaries - data_dicts = [item.model_dump(mode="json") for item in data] - - # Prepare the parameters for the query - params = [ - [item.get("session_id") for item in data_dicts], # $1 - [item.get("head") for item in data_dicts], # $2 - [item.get("relation") for item in data_dicts], # $3 - [item.get("tail") for item in data_dicts], # $4 - [item.get("is_leaf", False) for item in data_dicts], # $5 - developer_id, # $6 - ] - - return ( - entry_relation_query, - params, - ) diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py new file mode 100644 index 000000000..9a5d6faa3 --- /dev/null +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -0,0 +1,128 @@ +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query for deleting entries with a developer check +delete_entry_query = parse_one(""" +DELETE FROM entries +USING developers +WHERE entries.session_id = $1 -- session_id + AND developers.developer_id = $2 -- developer_id + +RETURNING entries.session_id as session_id; +""").sql(pretty=True) + +# Define the raw SQL query for deleting entries with a developer check +delete_entry_relations_query = parse_one(""" +DELETE FROM entry_relations +WHERE entry_relations.session_id = $1 -- session_id +""").sql(pretty=True) + +# Define the raw SQL query for deleting entries with a developer check +delete_entry_relations_by_ids_query = parse_one(""" +DELETE FROM entry_relations +WHERE entry_relations.session_id = $1 -- session_id + AND (entry_relations.head = ANY($2) -- entry_ids + OR entry_relations.tail = ANY($2)) -- entry_ids +""").sql(pretty=True) + +# Define the raw SQL query for deleting entries by entry_ids with a developer check +delete_entry_by_ids_query = parse_one(""" +DELETE FROM entries +USING developers +WHERE entries.entry_id = ANY($1) -- entry_ids + AND developers.developer_id = $2 -- developer_id + AND entries.session_id = $3 -- session_id + +RETURNING entries.entry_id as entry_id; +""").sql(pretty=True) + +# Add a session_exists_query similar to create_entries.py +session_exists_query = """ +SELECT EXISTS ( + SELECT 1 + FROM sessions + WHERE session_id = $1 + AND developer_id = $2 +); +""" + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail="The specified session or developer does not exist.", + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=409, + detail="The specified session has already been deleted.", + ), + } +) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_entries_for_session") +@pg_query +@beartype +async def delete_entries_for_session( + *, + developer_id: UUID, + session_id: UUID, +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + """Delete all entries for a given session.""" + return [ + (session_exists_query, [session_id, developer_id], "fetch"), + (delete_entry_relations_query, [session_id], "fetchmany"), + (delete_entry_query, [session_id, developer_id], "fetchmany"), + ] + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail="The specified entries, session, or developer does not exist.", + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=409, + detail="One or more specified entries have already been deleted.", + ), + } +) +@wrap_in_class( + ResourceDeletedResponse, + transform=lambda d: { + "id": d["entry_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_entries") +@pg_query +@beartype +async def delete_entries( + *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] +) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: + """Delete specific entries by their IDs.""" + return [ + (session_exists_query, [session_id, developer_id], "fetch"), + (delete_entry_relations_by_ids_query, [session_id, entry_ids], "fetchmany"), + (delete_entry_by_ids_query, [entry_ids, developer_id, session_id], "fetchmany"), + ] diff --git a/agents-api/agents_api/queries/entries/delete_entry.py b/agents-api/agents_api/queries/entries/delete_entry.py deleted file mode 100644 index d6cdc6e87..000000000 --- a/agents-api/agents_api/queries/entries/delete_entry.py +++ /dev/null @@ -1,96 +0,0 @@ -from uuid import UUID - -import asyncpg -from beartype import beartype -from fastapi import HTTPException -from sqlglot import parse_one - -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class - -# Define the raw SQL query for deleting entries with a developer check -entry_query = parse_one(""" -DELETE FROM entries -USING developers -WHERE entries.session_id = $1 -- session_id -AND developers.developer_id = $2 -RETURNING entries.session_id as session_id; -""").sql(pretty=True) - -# Define the raw SQL query for deleting entries by entry_ids with a developer check -delete_entry_by_ids_query = parse_one(""" -DELETE FROM entries -USING developers -WHERE entries.entry_id = ANY($1) -- entry_ids -AND developers.developer_id = $2 -AND entries.session_id = $3 -- session_id -RETURNING entries.entry_id as entry_id; -""").sql(pretty=True) - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=400, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - } -) -@wrap_in_class( - ResourceDeletedResponse, - one=True, - transform=lambda d: { - "id": d["session_id"], # Only return session cleared - "deleted_at": utcnow(), - "jobs": [], - }, -) -@pg_query -@beartype -async def delete_entries_for_session( - *, - developer_id: UUID, - session_id: UUID, -) -> tuple[str, list]: - return ( - entry_query, - [session_id, developer_id], - ) - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: partialclass( - HTTPException, - status_code=400, - detail="The specified developer does not exist.", - ), - asyncpg.UniqueViolationError: partialclass( - HTTPException, - status_code=404, - detail="One or more specified entries do not exist.", - ), - } -) -@wrap_in_class( - ResourceDeletedResponse, - transform=lambda d: { - "id": d["entry_id"], - "deleted_at": utcnow(), - "jobs": [], - }, -) -@pg_query -@beartype -async def delete_entries( - *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] -) -> tuple[str, list]: - return ( - delete_entry_by_ids_query, - [entry_ids, developer_id, session_id], - ) diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index c6c38d366..8f0ddf4a1 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -6,7 +6,7 @@ from sqlglot import parse_one from ...autogen.openapi_model import History -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for getting history with a developer check history_query = parse_one(""" diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py new file mode 100644 index 000000000..a3fa6d0a0 --- /dev/null +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -0,0 +1,112 @@ +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException + +from ...autogen.openapi_model import Entry +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class + +# Query for checking if the session exists +session_exists_query = """ +SELECT CASE + WHEN EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 + ) + THEN TRUE + ELSE (SELECT NULL::boolean WHERE FALSE) -- This raises a NO_DATA_FOUND error +END; +""" + +list_entries_query = """ +SELECT + e.entry_id as id, + e.session_id, + e.role, + e.name, + e.content, + e.source, + e.token_count, + e.created_at, + e.timestamp, + e.event_type, + e.tool_call_id, + e.tool_calls, + e.model +FROM entries e +JOIN developers d ON d.developer_id = $5 +LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id +WHERE e.session_id = $1 +AND e.source = ANY($2) +AND (er.relation IS NULL OR er.relation != ALL($6)) +ORDER BY e.{sort_by} {direction} -- safe to interpolate +LIMIT $3 +OFFSET $4; +""" + + +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( + status_code=404, + detail=str(exc), + ), + asyncpg.UniqueViolationError: lambda exc: HTTPException( + status_code=409, + detail=str(exc), + ), + asyncpg.NotNullViolationError: lambda exc: HTTPException( + status_code=400, + detail=str(exc), + ), + } +) +@wrap_in_class(Entry) +@increase_counter("list_entries") +@pg_query +@beartype +async def list_entries( + *, + developer_id: UUID, + session_id: UUID, + allowed_sources: list[str] = ["api_request", "api_response"], + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "timestamp"] = "timestamp", + direction: Literal["asc", "desc"] = "asc", + exclude_relations: list[str] = [], +) -> list[tuple[str, list]]: + if limit < 1 or limit > 1000: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be non-negative") + + query = list_entries_query.format( + sort_by=sort_by, + direction=direction, + ) + + # Parameters for the entry query + entry_params = [ + session_id, # $1 + allowed_sources, # $2 + limit, # $3 + offset, # $4 + developer_id, # $5 + exclude_relations, # $6 + ] + + return [ + ( + session_exists_query, + [session_id, developer_id], + ), + ( + query, + entry_params, + ), + ] + diff --git a/agents-api/agents_api/queries/entries/list_entry.py b/agents-api/agents_api/queries/entries/list_entry.py deleted file mode 100644 index 1fa6479d1..000000000 --- a/agents-api/agents_api/queries/entries/list_entry.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Literal -from uuid import UUID - -import asyncpg -from beartype import beartype -from fastapi import HTTPException - -from ...autogen.openapi_model import Entry -from ..utils import pg_query, rewrap_exceptions, wrap_in_class - -entry_query = """ -SELECT - e.entry_id as id, -- entry_id - e.session_id, -- session_id - e.role, -- role - e.name, -- name - e.content, -- content - e.source, -- source - e.token_count, -- token_count - e.created_at, -- created_at - e.timestamp -- timestamp -FROM entries e -JOIN developers d ON d.developer_id = $7 -LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id -WHERE e.session_id = $1 -AND e.source = ANY($2) -AND (er.relation IS NULL OR er.relation != ALL($8)) -ORDER BY e.$3 $4 -LIMIT $5 -OFFSET $6; -""" - - -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - } -) -@wrap_in_class(Entry) -@pg_query -@beartype -async def list_entries( - *, - developer_id: UUID, - session_id: UUID, - allowed_sources: list[str] = ["api_request", "api_response"], - limit: int = 1, - offset: int = 0, - sort_by: Literal["created_at", "timestamp"] = "timestamp", - direction: Literal["asc", "desc"] = "asc", - exclude_relations: list[str] = [], -) -> tuple[str, list]: - if limit < 1 or limit > 1000: - raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") - if offset < 0: - raise HTTPException(status_code=400, detail="Offset must be non-negative") - - # making the parameters for the query - params = [ - session_id, # $1 - allowed_sources, # $2 - sort_by, # $3 - direction, # $4 - limit, # $5 - offset, # $6 - developer_id, # $7 - exclude_relations, # $8 - ] - return ( - entry_query, - params, - ) diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 3b5dc0bb0..db583e08f 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -3,16 +3,27 @@ import socket import time from functools import partialmethod, wraps -from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeVar, cast +from typing import ( + Any, + Awaitable, + Callable, + Literal, + NotRequired, + ParamSpec, + Type, + TypeVar, + cast, +) import asyncpg -import pandas as pd from asyncpg import Record from beartype import beartype from fastapi import HTTPException from pydantic import BaseModel +from typing_extensions import TypedDict from ..app import app +from ..env import query_timeout P = ParamSpec("P") T = TypeVar("T") @@ -31,15 +42,61 @@ class NewCls(cls): return NewCls +class AsyncPGFetchArgs(TypedDict): + query: str + args: list[Any] + timeout: NotRequired[float | None] + + +type SQLQuery = str +type FetchMethod = Literal["fetch", "fetchmany"] +type PGQueryArgs = tuple[SQLQuery, list[Any]] | tuple[SQLQuery, list[Any], FetchMethod] +type PreparedPGQueryArgs = tuple[FetchMethod, AsyncPGFetchArgs] +type BatchedPreparedPGQueryArgs = list[PreparedPGQueryArgs] + + +@beartype +def prepare_pg_query_args( + query_args: PGQueryArgs | list[PGQueryArgs], +) -> BatchedPreparedPGQueryArgs: + batch = [] + query_args = [query_args] if isinstance(query_args, tuple) else query_args + + for query_arg in query_args: + match query_arg: + case (query, variables) | (query, variables, "fetch"): + batch.append( + ( + "fetch", + AsyncPGFetchArgs( + query=query, args=variables, timeout=query_timeout + ), + ) + ) + case (query, variables, "fetchmany"): + batch.append( + ( + "fetchmany", + AsyncPGFetchArgs( + query=query, args=[variables], timeout=query_timeout + ), + ) + ) + case _: + raise ValueError("Invalid query arguments") + + return batch + + @beartype def pg_query( - func: Callable[P, tuple[str | list[str | None], dict]] | None = None, + func: Callable[P, PGQueryArgs | list[PGQueryArgs]] | None = None, debug: bool | None = None, only_on_error: bool = False, timeit: bool = False, ) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]: def pg_query_dec( - func: Callable[P, tuple[str, list[Any]] | list[tuple[str, list[Any]]]], + func: Callable[P, PGQueryArgs | list[PGQueryArgs]], ) -> Callable[..., Callable[P, list[Record]]]: """ Decorator that wraps a function that takes arbitrary arguments, and @@ -57,14 +114,10 @@ async def wrapper( connection_pool: asyncpg.Pool | None = None, **kwargs: P.kwargs, ) -> list[Record]: - query, variables = await func(*args, **kwargs) + query_args = await func(*args, **kwargs) + batch = prepare_pg_query_args(query_args) - not only_on_error and debug and print(query) - not only_on_error and debug and pprint( - dict( - variables=variables, - ) - ) + not only_on_error and debug and pprint(batch) # Run the query pool = ( @@ -73,20 +126,20 @@ async def wrapper( else cast(asyncpg.Pool, app.state.postgres_pool) ) - assert isinstance(variables, list) and len(variables) > 0 - - queries = query if isinstance(query, list) else [query] - variables_list = ( - variables if isinstance(variables[0], list) else [variables] - ) - zipped = zip(queries, variables_list) - try: async with pool.acquire() as conn: async with conn.transaction(): start = timeit and time.perf_counter() - for query, variables in zipped: - results: list[Record] = await conn.fetch(query, *variables) + for method_name, payload in batch: + method = getattr(conn, method_name) + + query = payload["query"] + args = payload["args"] + timeout = payload.get("timeout") + + results: list[Record] = await method( + query, *args, timeout=timeout + ) end = timeit and time.perf_counter() @@ -96,8 +149,7 @@ async def wrapper( except Exception as e: if only_on_error and debug: - print(query) - pprint(variables) + pprint(batch) debug and print(repr(e)) connection_error = isinstance( @@ -113,11 +165,7 @@ async def wrapper( raise - not only_on_error and debug and pprint( - dict( - results=[dict(result.items()) for result in results], - ) - ) + not only_on_error and debug and pprint(results) return results @@ -210,7 +258,7 @@ async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: result: T = await func(*args, **kwargs) except BaseException as error: _check_error(error) - raise + raise error return result @@ -220,7 +268,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: result: T = func(*args, **kwargs) except BaseException as error: _check_error(error) - raise + raise error return result diff --git a/agents-api/agents_api/web.py b/agents-api/agents_api/web.py index 379526e0f..a04a7fc66 100644 --- a/agents-api/agents_api/web.py +++ b/agents-api/agents_api/web.py @@ -20,7 +20,6 @@ from .app import app from .common.exceptions import BaseCommonException -from .dependencies.auth import get_api_key from .env import api_prefix, hostname, protocol, public_port, sentry_dsn from .exceptions import PromptTooBigError diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index c2aa350a8..4a02efac4 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,24 +1,12 @@ -import json import random import string -import time from uuid import UUID -import asyncpg from fastapi.testclient import TestClient -from temporalio.client import WorkflowHandle from uuid_extensions import uuid7 from ward import fixture from agents_api.autogen.openapi_model import ( - CreateAgentRequest, - CreateDocRequest, - CreateExecutionRequest, - CreateFileRequest, - CreateSessionRequest, - CreateTaskRequest, - CreateToolRequest, - CreateTransitionRequest, CreateUserRequest, ) from agents_api.clients.pg import create_db_pool @@ -43,7 +31,6 @@ # from agents_api.queries.tools.create_tools import create_tools # from agents_api.queries.tools.delete_tool import delete_tool from agents_api.queries.users.create_user import create_user -from agents_api.queries.users.delete_user import delete_user from agents_api.web import app from .utils import ( diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index c07891305..87d9cdb4f 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,27 +3,21 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ -from uuid import UUID +from uuid import uuid4 -from ward import test +from fastapi import HTTPException +from ward import raises, test -from agents_api.autogen.openapi_model import CreateEntryRequest, Entry +from agents_api.autogen.openapi_model import CreateEntryRequest from agents_api.clients.pg import create_db_pool -from agents_api.queries.entries.create_entry import create_entries -from agents_api.queries.entries.delete_entry import delete_entries -from agents_api.queries.entries.get_history import get_history -from agents_api.queries.entries.list_entry import list_entries -from tests.fixtures import pg_dsn, test_developer_id # , test_session +from agents_api.queries.entries import create_entries, list_entries +from tests.fixtures import pg_dsn, test_developer # , test_session -# Test UUIDs for consistent testing MODEL = "gpt-4o-mini" -SESSION_ID = UUID("123e4567-e89b-12d3-a456-426614174001") -TEST_DEVELOPER_ID = UUID("123e4567-e89b-12d3-a456-426614174000") -TEST_USER_ID = UUID("987e6543-e21b-12d3-a456-426614174000") -@test("query: create entry") -async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +@test("query: create entry no session") +async def _(dsn=pg_dsn, developer=test_developer): """Test the addition of a new entry to the database.""" pool = await create_db_pool(dsn=dsn) @@ -34,12 +28,31 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_sessi content="test entry content", ) - await create_entries( - developer_id=TEST_DEVELOPER_ID, - session_id=SESSION_ID, - data=[test_entry], - connection_pool=pool, - ) + with raises(HTTPException) as exc_info: + await create_entries( + developer_id=developer.id, + session_id=uuid4(), + data=[test_entry], + connection_pool=pool, + ) + + assert exc_info.raised.status_code == 404 + + +@test("query: list entries no session") +async def _(dsn=pg_dsn, developer=test_developer): + """Test the retrieval of entries from the database.""" + + pool = await create_db_pool(dsn=dsn) + + with raises(HTTPException) as exc_info: + await list_entries( + developer_id=developer.id, + session_id=uuid4(), + connection_pool=pool, + ) + + assert exc_info.raised.status_code == 404 # @test("query: get entries") diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index d182586dc..4fdc7e6e4 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -3,36 +3,21 @@ Tests verify the SQL queries without actually executing them against a database. """ -from uuid import UUID - -import asyncpg from uuid_extensions import uuid7 from ward import raises, test from agents_api.autogen.openapi_model import ( - CreateOrUpdateSessionRequest, - CreateSessionRequest, - PatchSessionRequest, - ResourceDeletedResponse, - ResourceUpdatedResponse, Session, - UpdateSessionRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.queries.sessions import ( count_sessions, - create_or_update_session, - create_session, - delete_session, get_session, list_sessions, - patch_session, - update_session, ) from tests.fixtures import ( pg_dsn, test_developer_id, - test_user, ) # , test_session, test_agent, test_user # @test("query: create session sql") From 2b8686c2f52996899eb41cf35a0dbacbc0d07d06 Mon Sep 17 00:00:00 2001 From: creatorrr Date: Wed, 18 Dec 2024 14:27:48 +0000 Subject: [PATCH 14/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/entries/list_entries.py | 1 - 1 file changed, 1 deletion(-) diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index a3fa6d0a0..0aeb92a25 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -109,4 +109,3 @@ async def list_entries( entry_params, ), ] - From 94aa3ce1684b0a058d4b3bd0cf68e630918fb2cb Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Wed, 18 Dec 2024 18:11:02 +0300 Subject: [PATCH 15/29] fix(agents-api): change modelname to model in BaseEntry --- agents-api/agents_api/autogen/Entries.py | 2 +- agents-api/agents_api/autogen/openapi_model.py | 2 +- agents-api/agents_api/queries/entries/create_entries.py | 2 +- agents-api/agents_api/queries/entries/delete_entries.py | 2 +- integrations-service/integrations/autogen/Entries.py | 2 +- typespec/entries/models.tsp | 2 +- typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml | 6 +++--- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/agents-api/agents_api/autogen/Entries.py b/agents-api/agents_api/autogen/Entries.py index d195b518f..867b10192 100644 --- a/agents-api/agents_api/autogen/Entries.py +++ b/agents-api/agents_api/autogen/Entries.py @@ -52,7 +52,7 @@ class BaseEntry(BaseModel): ] tokenizer: str token_count: int - modelname: str = "gpt-40-mini" + model: str = "gpt-4o-mini" tool_calls: ( list[ ChosenFunctionCall diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index 01042c58c..af73e8015 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -400,7 +400,7 @@ def from_model_input( source=source, tokenizer=tokenizer["type"], token_count=token_count, - modelname=model, + model=model, **kwargs, ) diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index ffbd2de22..24c0be26e 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -107,7 +107,7 @@ async def create_entries( content_to_json(item.get("content") or {}), # $7 item.get("tool_call_id"), # $8 content_to_json(item.get("tool_calls") or {}), # $9 - item.get("modelname"), # $10 + item.get("model"), # $10 item.get("token_count"), # $11 item.get("created_at") or utcnow(), # $12 utcnow(), # $13 diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py index 9a5d6faa3..dfdadb8da 100644 --- a/agents-api/agents_api/queries/entries/delete_entries.py +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -9,7 +9,7 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for deleting entries with a developer check delete_entry_query = parse_one(""" diff --git a/integrations-service/integrations/autogen/Entries.py b/integrations-service/integrations/autogen/Entries.py index d195b518f..867b10192 100644 --- a/integrations-service/integrations/autogen/Entries.py +++ b/integrations-service/integrations/autogen/Entries.py @@ -52,7 +52,7 @@ class BaseEntry(BaseModel): ] tokenizer: str token_count: int - modelname: str = "gpt-40-mini" + model: str = "gpt-4o-mini" tool_calls: ( list[ ChosenFunctionCall diff --git a/typespec/entries/models.tsp b/typespec/entries/models.tsp index 640e6831d..d7eae55e7 100644 --- a/typespec/entries/models.tsp +++ b/typespec/entries/models.tsp @@ -107,7 +107,7 @@ model BaseEntry { tokenizer: string; token_count: uint16; - modelname: string = "gpt-40-mini"; + "model": string = "gpt-4o-mini"; /** Tool calls generated by the model. */ tool_calls?: ChosenToolCall[] | null = null; diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index 9b36baa2b..9298ab458 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -3064,7 +3064,7 @@ components: - source - tokenizer - token_count - - modelname + - model - timestamp properties: role: @@ -3308,9 +3308,9 @@ components: token_count: type: integer format: uint16 - modelname: + model: type: string - default: gpt-40-mini + default: gpt-4o-mini tool_calls: type: array items: From 64a34cdac3883d63d1764e9473fcab982ab346bd Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Tue, 17 Dec 2024 13:39:21 +0300 Subject: [PATCH 16/29] feat(agents-api): add agent queries tests --- .../agents_api/queries/agents/__init__.py | 12 +- .../agents_api/queries/agents/create_agent.py | 61 ++- .../queries/agents/create_or_update_agent.py | 21 +- .../agents_api/queries/agents/delete_agent.py | 23 +- .../agents_api/queries/agents/get_agent.py | 24 +- .../agents_api/queries/agents/list_agents.py | 23 +- .../agents_api/queries/agents/patch_agent.py | 23 +- .../agents_api/queries/agents/update_agent.py | 23 +- agents-api/tests/fixtures.py | 34 +- agents-api/tests/test_agent_queries.py | 350 ++++++++++-------- 10 files changed, 307 insertions(+), 287 deletions(-) diff --git a/agents-api/agents_api/queries/agents/__init__.py b/agents-api/agents_api/queries/agents/__init__.py index 709b051ea..ebd169040 100644 --- a/agents-api/agents_api/queries/agents/__init__.py +++ b/agents-api/agents_api/queries/agents/__init__.py @@ -13,9 +13,9 @@ # ruff: noqa: F401, F403, F405 from .create_agent import create_agent -from .create_or_update_agent import create_or_update_agent_query -from .delete_agent import delete_agent_query -from .get_agent import get_agent_query -from .list_agents import list_agents_query -from .patch_agent import patch_agent_query -from .update_agent import update_agent_query +from .create_or_update_agent import create_or_update_agent +from .delete_agent import delete_agent +from .get_agent import get_agent +from .list_agents import list_agents +from .patch_agent import patch_agent +from .update_agent import update_agent diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 4c731d3dd..cbdb32972 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -8,7 +8,6 @@ from beartype import beartype from fastapi import HTTPException -from psycopg import errors as psycopg_errors from pydantic import ValidationError from uuid_extensions import uuid7 @@ -25,35 +24,35 @@ T = TypeVar("T") -@rewrap_exceptions( - { - psycopg_errors.ForeignKeyViolation: partialclass( - HTTPException, - status_code=404, - detail="The specified developer does not exist.", - ), - psycopg_errors.UniqueViolation: partialclass( - HTTPException, - status_code=409, - detail="An agent with this canonical name already exists for this developer.", - ), - psycopg_errors.CheckViolation: partialclass( - HTTPException, - status_code=400, - detail="The provided data violates one or more constraints. Please check the input values.", - ), - ValidationError: partialclass( - HTTPException, - status_code=400, - detail="Input validation failed. Please check the provided data.", - ), - TypeError: partialclass( - HTTPException, - status_code=400, - detail="A type mismatch occurred. Please review the input.", - ), - } -) +# @rewrap_exceptions( +# { +# psycopg_errors.ForeignKeyViolation: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist.", +# ), +# psycopg_errors.UniqueViolation: partialclass( +# HTTPException, +# status_code=409, +# detail="An agent with this canonical name already exists for this developer.", +# ), +# psycopg_errors.CheckViolation: partialclass( +# HTTPException, +# status_code=400, +# detail="The provided data violates one or more constraints. Please check the input values.", +# ), +# ValidationError: partialclass( +# HTTPException, +# status_code=400, +# detail="Input validation failed. Please check the provided data.", +# ), +# TypeError: partialclass( +# HTTPException, +# status_code=400, +# detail="A type mismatch occurred. Please review the input.", +# ), +# } +# ) @wrap_in_class( Agent, one=True, @@ -63,7 +62,7 @@ @pg_query # @increase_counter("create_agent") @beartype -def create_agent( +async def create_agent( *, developer_id: UUID, agent_id: UUID | None = None, diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index 96681255c..9c92f0b46 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -8,7 +8,6 @@ from beartype import beartype from fastapi import HTTPException -from psycopg import errors as psycopg_errors from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest from ..utils import ( @@ -23,15 +22,15 @@ T = TypeVar("T") -@rewrap_exceptions( - { - psycopg_errors.ForeignKeyViolation: partialclass( - HTTPException, - status_code=404, - detail="The specified developer does not exist.", - ) - } -) +# @rewrap_exceptions( +# { +# psycopg_errors.ForeignKeyViolation: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist.", +# ) +# } +# ) @wrap_in_class( Agent, one=True, @@ -41,7 +40,7 @@ @pg_query # @increase_counter("create_or_update_agent1") @beartype -def create_or_update_agent_query( +async def create_or_update_agent( *, agent_id: UUID, developer_id: UUID, data: CreateOrUpdateAgentRequest ) -> tuple[list[str], dict]: """ diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index f3c64fd18..545a976d5 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -8,7 +8,6 @@ from beartype import beartype from fastapi import HTTPException -from psycopg import errors as psycopg_errors from ...autogen.openapi_model import ResourceDeletedResponse from ..utils import ( @@ -22,16 +21,16 @@ T = TypeVar("T") -@rewrap_exceptions( - { - psycopg_errors.ForeignKeyViolation: partialclass( - HTTPException, - status_code=404, - detail="The specified developer does not exist.", - ) - } - # TODO: Add more exceptions -) +# @rewrap_exceptions( +# { +# psycopg_errors.ForeignKeyViolation: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist.", +# ) +# } +# # TODO: Add more exceptions +# ) @wrap_in_class( ResourceDeletedResponse, one=True, @@ -42,7 +41,7 @@ @pg_query # @increase_counter("delete_agent1") @beartype -def delete_agent_query(*, agent_id: UUID, developer_id: UUID) -> tuple[list[str], dict]: +async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[list[str], dict]: """ Constructs the SQL queries to delete an agent and its related settings. diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index 5e0edbb98..18d253e8d 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -8,8 +8,6 @@ from beartype import beartype from fastapi import HTTPException -from psycopg import errors as psycopg_errors - from ...autogen.openapi_model import Agent from ..utils import ( partialclass, @@ -22,21 +20,21 @@ T = TypeVar("T") -@rewrap_exceptions( - { - psycopg_errors.ForeignKeyViolation: partialclass( - HTTPException, - status_code=404, - detail="The specified developer does not exist.", - ) - } - # TODO: Add more exceptions -) +# @rewrap_exceptions( + # { + # psycopg_errors.ForeignKeyViolation: partialclass( + # HTTPException, + # status_code=404, + # detail="The specified developer does not exist.", + # ) + # } + # # TODO: Add more exceptions +# ) @wrap_in_class(Agent, one=True) @pg_query # @increase_counter("get_agent1") @beartype -def get_agent_query(*, agent_id: UUID, developer_id: UUID) -> tuple[list[str], dict]: +async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[list[str], dict]: """ Constructs the SQL query to retrieve an agent's details. diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 5fda7c626..c24276a97 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -8,7 +8,6 @@ from beartype import beartype from fastapi import HTTPException -from psycopg import errors as psycopg_errors from ...autogen.openapi_model import Agent from ..utils import ( @@ -22,21 +21,21 @@ T = TypeVar("T") -@rewrap_exceptions( - { - psycopg_errors.ForeignKeyViolation: partialclass( - HTTPException, - status_code=404, - detail="The specified developer does not exist.", - ) - } - # TODO: Add more exceptions -) +# @rewrap_exceptions( +# { +# psycopg_errors.ForeignKeyViolation: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist.", +# ) +# } +# # TODO: Add more exceptions +# ) @wrap_in_class(Agent) @pg_query # @increase_counter("list_agents1") @beartype -def list_agents_query( +async def list_agents( *, developer_id: UUID, limit: int = 100, diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 450cbf8cc..d4adff092 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -8,7 +8,6 @@ from beartype import beartype from fastapi import HTTPException -from psycopg import errors as psycopg_errors from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ..utils import ( @@ -22,16 +21,16 @@ T = TypeVar("T") -@rewrap_exceptions( - { - psycopg_errors.ForeignKeyViolation: partialclass( - HTTPException, - status_code=404, - detail="The specified developer does not exist.", - ) - } - # TODO: Add more exceptions -) +# @rewrap_exceptions( +# { +# psycopg_errors.ForeignKeyViolation: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist.", +# ) +# } +# # TODO: Add more exceptions +# ) @wrap_in_class( ResourceUpdatedResponse, one=True, @@ -41,7 +40,7 @@ @pg_query # @increase_counter("patch_agent1") @beartype -def patch_agent_query( +async def patch_agent( *, agent_id: UUID, developer_id: UUID, data: PatchAgentRequest ) -> tuple[str, dict]: """ diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index 61548de70..2116e49b0 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -8,7 +8,6 @@ from beartype import beartype from fastapi import HTTPException -from psycopg import errors as psycopg_errors from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ..utils import ( @@ -22,16 +21,16 @@ T = TypeVar("T") -@rewrap_exceptions( - { - psycopg_errors.ForeignKeyViolation: partialclass( - HTTPException, - status_code=404, - detail="The specified developer does not exist.", - ) - } - # TODO: Add more exceptions -) +# @rewrap_exceptions( +# { +# psycopg_errors.ForeignKeyViolation: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist.", +# ) +# } +# # TODO: Add more exceptions +# ) @wrap_in_class( ResourceUpdatedResponse, one=True, @@ -41,7 +40,7 @@ @pg_query # @increase_counter("update_agent1") @beartype -def update_agent_query( +async def update_agent( *, agent_id: UUID, developer_id: UUID, data: UpdateAgentRequest ) -> tuple[str, dict]: """ diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 4a02efac4..1151b433d 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -13,7 +13,7 @@ from agents_api.env import api_key, api_key_header_name, multi_tenant_mode from agents_api.queries.developers.create_developer import create_developer -# from agents_api.queries.agents.create_agent import create_agent +from agents_api.queries.agents.create_agent import create_agent # from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer @@ -93,20 +93,24 @@ def patch_embed_acompletion(): yield embed, acompletion -# @fixture(scope="global") -# async def test_agent(dsn=pg_dsn, developer_id=test_developer_id): -# async with get_pg_client(dsn=dsn) as client: -# agent = await create_agent( -# developer_id=developer_id, -# data=CreateAgentRequest( -# model="gpt-4o-mini", -# name="test agent", -# about="test agent about", -# metadata={"test": "test"}, -# ), -# client=client, -# ) -# yield agent +@fixture(scope="global") +async def test_agent(dsn=pg_dsn, developer=test_developer): + pool = await asyncpg.create_pool(dsn=dsn) + + async with get_pg_client(pool=pool) as client: + agent = await create_agent( + developer_id=developer.id, + data=CreateAgentRequest( + model="gpt-4o-mini", + name="test agent", + about="test agent about", + metadata={"test": "test"}, + ), + client=client, + ) + + yield agent + await pool.close() @fixture(scope="global") diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index f079642b3..f8f75fd0b 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,163 +1,187 @@ -# # Tests for agent queries - -# from uuid_extensions import uuid7 -# from ward import raises, test - -# from agents_api.autogen.openapi_model import ( -# Agent, -# CreateAgentRequest, -# CreateOrUpdateAgentRequest, -# PatchAgentRequest, -# ResourceUpdatedResponse, -# UpdateAgentRequest, -# ) -# from agents_api.queries.agent.create_agent import create_agent -# from agents_api.queries.agent.create_or_update_agent import create_or_update_agent -# from agents_api.queries.agent.delete_agent import delete_agent -# from agents_api.queries.agent.get_agent import get_agent -# from agents_api.queries.agent.list_agents import list_agents -# from agents_api.queries.agent.patch_agent import patch_agent -# from agents_api.queries.agent.update_agent import update_agent -# from tests.fixtures import cozo_client, test_agent, test_developer_id - - -# @test("query: create agent") -# def _(client=cozo_client, developer_id=test_developer_id): -# create_agent( -# developer_id=developer_id, -# data=CreateAgentRequest( -# name="test agent", -# about="test agent about", -# model="gpt-4o-mini", -# ), -# client=client, -# ) - - -# @test("query: create agent with instructions") -# def _(client=cozo_client, developer_id=test_developer_id): -# create_agent( -# developer_id=developer_id, -# data=CreateAgentRequest( -# name="test agent", -# about="test agent about", -# model="gpt-4o-mini", -# instructions=["test instruction"], -# ), -# client=client, -# ) - - -# @test("query: create or update agent") -# def _(client=cozo_client, developer_id=test_developer_id): -# create_or_update_agent( -# developer_id=developer_id, -# agent_id=uuid7(), -# data=CreateOrUpdateAgentRequest( -# name="test agent", -# about="test agent about", -# model="gpt-4o-mini", -# instructions=["test instruction"], -# ), -# client=client, -# ) - - -# @test("query: get agent not exists") -# def _(client=cozo_client, developer_id=test_developer_id): -# agent_id = uuid7() - -# with raises(Exception): -# get_agent(agent_id=agent_id, developer_id=developer_id, client=client) - - -# @test("query: get agent exists") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# result = get_agent(agent_id=agent.id, developer_id=developer_id, client=client) - -# assert result is not None -# assert isinstance(result, Agent) - - -# @test("query: delete agent") -# def _(client=cozo_client, developer_id=test_developer_id): -# temp_agent = create_agent( -# developer_id=developer_id, -# data=CreateAgentRequest( -# name="test agent", -# about="test agent about", -# model="gpt-4o-mini", -# instructions=["test instruction"], -# ), -# client=client, -# ) - -# # Delete the agent -# delete_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) - -# # Check that the agent is deleted -# with raises(Exception): -# get_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) - - -# @test("query: update agent") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# result = update_agent( -# agent_id=agent.id, -# developer_id=developer_id, -# data=UpdateAgentRequest( -# name="updated agent", -# about="updated agent about", -# model="gpt-4o-mini", -# default_settings={"temperature": 1.0}, -# metadata={"hello": "world"}, -# ), -# client=client, -# ) - -# assert result is not None -# assert isinstance(result, ResourceUpdatedResponse) - -# agent = get_agent( -# agent_id=agent.id, -# developer_id=developer_id, -# client=client, -# ) - -# assert "test" not in agent.metadata - - -# @test("query: patch agent") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# result = patch_agent( -# agent_id=agent.id, -# developer_id=developer_id, -# data=PatchAgentRequest( -# name="patched agent", -# about="patched agent about", -# default_settings={"temperature": 1.0}, -# metadata={"something": "else"}, -# ), -# client=client, -# ) - -# assert result is not None -# assert isinstance(result, ResourceUpdatedResponse) - -# agent = get_agent( -# agent_id=agent.id, -# developer_id=developer_id, -# client=client, -# ) - -# assert "hello" in agent.metadata - - -# @test("query: list agents") -# def _(client=cozo_client, developer_id=test_developer_id, agent=test_agent): -# """Tests listing all agents associated with a developer in the database. Verifies that the correct list of agents is retrieved.""" - -# result = list_agents(developer_id=developer_id, client=client) - -# assert isinstance(result, list) -# assert all(isinstance(agent, Agent) for agent in result) +# Tests for agent queries +from uuid import uuid4 + +import asyncpg +from ward import raises, test + +from agents_api.autogen.openapi_model import ( + Agent, + CreateAgentRequest, + CreateOrUpdateAgentRequest, + PatchAgentRequest, + ResourceUpdatedResponse, + UpdateAgentRequest, +) +from agents_api.clients.pg import get_pg_client +from agents_api.queries.agents import ( + create_agent, + create_or_update_agent, + delete_agent, + get_agent, + list_agents, + patch_agent, + update_agent, +) +from tests.fixtures import pg_dsn, test_agent, test_developer_id + + +@test("model: create agent") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + await create_agent( + developer_id=developer_id, + data=CreateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + ), + client=client, + ) + + +@test("model: create agent with instructions") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + await create_agent( + developer_id=developer_id, + data=CreateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + instructions=["test instruction"], + ), + client=client, + ) + + +@test("model: create or update agent") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + await create_or_update_agent( + developer_id=developer_id, + agent_id=uuid4(), + data=CreateOrUpdateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + instructions=["test instruction"], + ), + client=client, + ) + + +@test("model: get agent not exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + agent_id = uuid4() + pool = await asyncpg.create_pool(dsn=dsn) + + with raises(Exception): + async with get_pg_client(pool=pool) as client: + await get_agent(agent_id=agent_id, developer_id=developer_id, client=client) + + +@test("model: get agent exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + result = await get_agent(agent_id=agent.id, developer_id=developer_id, client=client) + + assert result is not None + assert isinstance(result, Agent) + + +@test("model: delete agent") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + temp_agent = await create_agent( + developer_id=developer_id, + data=CreateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + instructions=["test instruction"], + ), + client=client, + ) + + # Delete the agent + await delete_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) + + # Check that the agent is deleted + with raises(Exception): + await get_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) + + +@test("model: update agent") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + result = await update_agent( + agent_id=agent.id, + developer_id=developer_id, + data=UpdateAgentRequest( + name="updated agent", + about="updated agent about", + model="gpt-4o-mini", + default_settings={"temperature": 1.0}, + metadata={"hello": "world"}, + ), + client=client, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + + async with get_pg_client(pool=pool) as client: + agent = await get_agent( + agent_id=agent.id, + developer_id=developer_id, + client=client, + ) + + assert "test" not in agent.metadata + + +@test("model: patch agent") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + result = await patch_agent( + agent_id=agent.id, + developer_id=developer_id, + data=PatchAgentRequest( + name="patched agent", + about="patched agent about", + default_settings={"temperature": 1.0}, + metadata={"something": "else"}, + ), + client=client, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + + async with get_pg_client(pool=pool) as client: + agent = await get_agent( + agent_id=agent.id, + developer_id=developer_id, + client=client, + ) + + assert "hello" in agent.metadata + + +@test("model: list agents") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Tests listing all agents associated with a developer in the database. Verifies that the correct list of agents is retrieved.""" + + pool = await asyncpg.create_pool(dsn=dsn) + async with get_pg_client(pool=pool) as client: + result = await list_agents(developer_id=developer_id, client=client) + + assert isinstance(result, list) + assert all(isinstance(agent, Agent) for agent in result) From 8cc2ae31b95e596edc69f0ccf80f7695afd52a24 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Wed, 18 Dec 2024 11:48:48 +0300 Subject: [PATCH 17/29] feat(agents-api): implement agent queries and tests --- .../agents_api/queries/agents/create_agent.py | 85 ++++--- .../queries/agents/create_or_update_agent.py | 88 ++++--- .../agents_api/queries/agents/delete_agent.py | 82 +++--- .../agents_api/queries/agents/get_agent.py | 53 ++-- .../agents_api/queries/agents/list_agents.py | 82 +++--- .../agents_api/queries/agents/patch_agent.py | 73 ++++-- .../agents_api/queries/agents/update_agent.py | 57 +++-- agents-api/agents_api/queries/utils.py | 14 + agents-api/tests/fixtures.py | 26 +- agents-api/tests/test_agent_queries.py | 239 ++++++++---------- 10 files changed, 430 insertions(+), 369 deletions(-) diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index cbdb32972..63ac4870f 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -6,6 +6,7 @@ from typing import Any, TypeVar from uuid import UUID +from sqlglot import parse_one from beartype import beartype from fastapi import HTTPException from pydantic import ValidationError @@ -13,7 +14,7 @@ from ...autogen.openapi_model import Agent, CreateAgentRequest from ..utils import ( - # generate_canonical_name, + generate_canonical_name, partialclass, pg_query, rewrap_exceptions, @@ -23,6 +24,33 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +raw_query = """ +INSERT INTO agents ( + developer_id, + agent_id, + canonical_name, + name, + about, + instructions, + model, + metadata, + default_settings +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9 +) +RETURNING *; +""" + +query = parse_one(raw_query).sql(pretty=True) # @rewrap_exceptions( # { @@ -57,17 +85,16 @@ Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}, - _kind="inserted", ) -@pg_query # @increase_counter("create_agent") +@pg_query @beartype async def create_agent( *, developer_id: UUID, agent_id: UUID | None = None, data: CreateAgentRequest, -) -> tuple[str, dict]: +) -> tuple[str, list]: """ Constructs and executes a SQL query to create a new agent in the database. @@ -90,49 +117,23 @@ async def create_agent( # Convert default_settings to dict if it exists default_settings = ( - data.default_settings.model_dump() if data.default_settings else None + data.default_settings.model_dump() if data.default_settings else {} ) # Set default values - data.metadata = data.metadata or None - # data.canonical_name = data.canonical_name or generate_canonical_name(data.name) + data.metadata = data.metadata or {} + data.canonical_name = data.canonical_name or generate_canonical_name(data.name) - query = """ - INSERT INTO agents ( + params = [ developer_id, agent_id, - canonical_name, - name, - about, - instructions, - model, - metadata, - default_settings - ) - VALUES ( - %(developer_id)s, - %(agent_id)s, - %(canonical_name)s, - %(name)s, - %(about)s, - %(instructions)s, - %(model)s, - %(metadata)s, - %(default_settings)s - ) - RETURNING *; - """ - - params = { - "developer_id": developer_id, - "agent_id": agent_id, - "canonical_name": data.canonical_name, - "name": data.name, - "about": data.about, - "instructions": data.instructions, - "model": data.model, - "metadata": data.metadata, - "default_settings": default_settings, - } + data.canonical_name, + data.name, + data.about, + data.instructions, + data.model, + data.metadata, + default_settings, + ] return query, params diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index 9c92f0b46..bbb897fe5 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -6,12 +6,15 @@ from typing import Any, TypeVar from uuid import UUID +from sqlglot import parse_one +from sqlglot.optimizer import optimize + from beartype import beartype from fastapi import HTTPException from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest from ..utils import ( - # generate_canonical_name, + generate_canonical_name, partialclass, pg_query, rewrap_exceptions, @@ -21,6 +24,34 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +raw_query = """ +INSERT INTO agents ( + developer_id, + agent_id, + canonical_name, + name, + about, + instructions, + model, + metadata, + default_settings +) +VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7, + $8, + $9 +) +RETURNING *; +""" + +query = parse_one(raw_query).sql(pretty=True) + # @rewrap_exceptions( # { @@ -35,14 +66,13 @@ Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}, - _kind="inserted", ) +# @increase_counter("create_or_update_agent") @pg_query -# @increase_counter("create_or_update_agent1") @beartype async def create_or_update_agent( *, agent_id: UUID, developer_id: UUID, data: CreateOrUpdateAgentRequest -) -> tuple[list[str], dict]: +) -> tuple[str, list]: """ Constructs the SQL queries to create a new agent or update an existing agent's details. @@ -64,49 +94,23 @@ async def create_or_update_agent( # Convert default_settings to dict if it exists default_settings = ( - data.default_settings.model_dump() if data.default_settings else None + data.default_settings.model_dump() if data.default_settings else {} ) # Set default values - data.metadata = data.metadata or None - # data.canonical_name = data.canonical_name or generate_canonical_name(data.name) + data.metadata = data.metadata or {} + data.canonical_name = data.canonical_name or generate_canonical_name(data.name) - query = """ - INSERT INTO agents ( + params = [ developer_id, agent_id, - canonical_name, - name, - about, - instructions, - model, - metadata, - default_settings - ) - VALUES ( - %(developer_id)s, - %(agent_id)s, - %(canonical_name)s, - %(name)s, - %(about)s, - %(instructions)s, - %(model)s, - %(metadata)s, - %(default_settings)s - ) - RETURNING *; - """ - - params = { - "developer_id": developer_id, - "agent_id": agent_id, - "canonical_name": data.canonical_name, - "name": data.name, - "about": data.about, - "instructions": data.instructions, - "model": data.model, - "metadata": data.metadata, - "default_settings": default_settings, - } + data.canonical_name, + data.name, + data.about, + data.instructions, + data.model, + data.metadata, + default_settings, + ] return (query, params) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 545a976d5..a5062f783 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -16,10 +16,40 @@ rewrap_exceptions, wrap_in_class, ) +from beartype import beartype +from sqlglot import parse_one +from sqlglot.optimizer import optimize +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +raw_query = """ +WITH deleted_docs AS ( + DELETE FROM docs + WHERE developer_id = $1 + AND doc_id IN ( + SELECT ad.doc_id + FROM agent_docs ad + WHERE ad.agent_id = $2 + AND ad.developer_id = $1 + ) +), deleted_agent_docs AS ( + DELETE FROM agent_docs + WHERE agent_id = $2 AND developer_id = $1 +), deleted_tools AS ( + DELETE FROM tools + WHERE agent_id = $2 AND developer_id = $1 +) +DELETE FROM agents +WHERE agent_id = $2 AND developer_id = $1 +RETURNING developer_id, agent_id; +""" + + +# Convert the list of queries into a single query string +query = parse_one(raw_query).sql(pretty=True) # @rewrap_exceptions( # { @@ -34,57 +64,23 @@ @wrap_in_class( ResourceDeletedResponse, one=True, - transform=lambda d: { - "id": d["agent_id"], - }, + transform=lambda d: {**d, "id": d["agent_id"], "deleted_at": utcnow()}, ) +# @increase_counter("delete_agent") @pg_query -# @increase_counter("delete_agent1") @beartype -async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[list[str], dict]: +async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: """ - Constructs the SQL queries to delete an agent and its related settings. + Constructs the SQL query to delete an agent and its related settings. Args: agent_id (UUID): The UUID of the agent to be deleted. developer_id (UUID): The UUID of the developer owning the agent. Returns: - tuple[list[str], dict]: A tuple containing the list of SQL queries and their parameters. + tuple[str, list]: A tuple containing the SQL query and its parameters. """ - - queries = [ - """ - -- Delete docs that were only associated with this agent - DELETE FROM docs - WHERE developer_id = %(developer_id)s - AND doc_id IN ( - SELECT ad.doc_id - FROM agent_docs ad - WHERE ad.agent_id = %(agent_id)s - AND ad.developer_id = %(developer_id)s - ); - """, - """ - -- Delete agent_docs entries - DELETE FROM agent_docs - WHERE agent_id = %(agent_id)s AND developer_id = %(developer_id)s; - """, - """ - -- Delete tools related to the agent - DELETE FROM tools - WHERE agent_id = %(agent_id)s AND developer_id = %(developer_id)s; - """, - """ - -- Delete the agent - DELETE FROM agents - WHERE agent_id = %(agent_id)s AND developer_id = %(developer_id)s; - """, - ] - - params = { - "agent_id": agent_id, - "developer_id": developer_id, - } - - return (queries, params) + # Note: We swap the parameter order because the queries use $1 for developer_id and $2 for agent_id + params = [developer_id, agent_id] + + return (query, params) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index 18d253e8d..061d0b165 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -9,12 +9,39 @@ from beartype import beartype from fastapi import HTTPException from ...autogen.openapi_model import Agent +from ...metrics.counters import increase_counter +from sqlglot import parse_one +from sqlglot.optimizer import optimize from ..utils import ( partialclass, pg_query, rewrap_exceptions, wrap_in_class, ) +from beartype import beartype + +from ...autogen.openapi_model import Agent + +raw_query = """ +SELECT + agent_id, + developer_id, + name, + canonical_name, + about, + instructions, + model, + metadata, + default_settings, + created_at, + updated_at +FROM + agents +WHERE + agent_id = $2 AND developer_id = $1; +""" + +query = parse_one(raw_query).sql(pretty=True) ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") @@ -30,11 +57,11 @@ # } # # TODO: Add more exceptions # ) -@wrap_in_class(Agent, one=True) +@wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}) +# @increase_counter("get_agent") @pg_query -# @increase_counter("get_agent1") @beartype -async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[list[str], dict]: +async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: """ Constructs the SQL query to retrieve an agent's details. @@ -45,23 +72,5 @@ async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[list[str], d Returns: tuple[list[str], dict]: A tuple containing the SQL query and its parameters. """ - query = """ - SELECT - agent_id, - developer_id, - name, - canonical_name, - about, - instructions, - model, - metadata, - default_settings, - created_at, - updated_at - FROM - agents - WHERE - agent_id = %(agent_id)s AND developer_id = %(developer_id)s; - """ - return (query, {"agent_id": agent_id, "developer_id": developer_id}) + return (query, [developer_id, agent_id]) diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index c24276a97..92165e414 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -16,12 +16,42 @@ rewrap_exceptions, wrap_in_class, ) +from beartype import beartype +from sqlglot import parse_one +from sqlglot.optimizer import optimize + +from ...autogen.openapi_model import Agent ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +raw_query = """ +SELECT + agent_id, + developer_id, + name, + canonical_name, + about, + instructions, + model, + metadata, + default_settings, + created_at, + updated_at +FROM agents +WHERE developer_id = $1 $7 +ORDER BY + CASE WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at END ASC NULLS LAST, + CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST, + CASE WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at END ASC NULLS LAST, + CASE WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at END DESC NULLS LAST +LIMIT $2 OFFSET $3; +""" + +query = raw_query + -# @rewrap_exceptions( +# @rewrap_exceptions( # { # psycopg_errors.ForeignKeyViolation: partialclass( # HTTPException, @@ -31,9 +61,9 @@ # } # # TODO: Add more exceptions # ) -@wrap_in_class(Agent) +@wrap_in_class(Agent, transform=lambda d: {"id": d["agent_id"], **d}) +# @increase_counter("list_agents") @pg_query -# @increase_counter("list_agents1") @beartype async def list_agents( *, @@ -43,7 +73,7 @@ async def list_agents( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", metadata_filter: dict[str, Any] = {}, -) -> tuple[str, dict]: +) -> tuple[str, list]: """ Constructs query to list agents for a developer with pagination. @@ -63,33 +93,25 @@ async def list_agents( raise HTTPException(status_code=400, detail="Invalid sort direction") # Build metadata filter clause if needed - metadata_clause = "" - if metadata_filter: - metadata_clause = "AND metadata @> %(metadata_filter)s::jsonb" - query = f""" - SELECT - agent_id, + final_query = query + if metadata_filter: + final_query = query.replace("$7", "AND metadata @> $6::jsonb") + else: + final_query = query.replace("$7", "") + + params = [ developer_id, - name, - canonical_name, - about, - instructions, - model, - metadata, - default_settings, - created_at, - updated_at - FROM agents - WHERE developer_id = %(developer_id)s - {metadata_clause} - ORDER BY {sort_by} {direction} - LIMIT %(limit)s OFFSET %(offset)s; - """ - - params = {"developer_id": developer_id, "limit": limit, "offset": offset} - + limit, + offset + ] + + params.append(sort_by) + params.append(direction) if metadata_filter: - params["metadata_filter"] = metadata_filter + params.append(metadata_filter) + + print(final_query) + print(params) - return query, params + return final_query, params diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index d4adff092..647ea3e52 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -10,6 +10,10 @@ from fastapi import HTTPException from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse +from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize +from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, @@ -19,6 +23,35 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") + +raw_query = """ +UPDATE agents +SET + name = CASE + WHEN $3::text IS NOT NULL THEN $3 + ELSE name + END, + about = CASE + WHEN $4::text IS NOT NULL THEN $4 + ELSE about + END, + metadata = CASE + WHEN $5::jsonb IS NOT NULL THEN metadata || $5 + ELSE metadata + END, + model = CASE + WHEN $6::text IS NOT NULL THEN $6 + ELSE model + END, + default_settings = CASE + WHEN $7::jsonb IS NOT NULL THEN $7 + ELSE default_settings + END +WHERE agent_id = $2 AND developer_id = $1 +RETURNING *; +""" + +query = parse_one(raw_query).sql(pretty=True) # @rewrap_exceptions( @@ -35,14 +68,13 @@ ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["agent_id"], **d}, - _kind="inserted", ) +# @increase_counter("patch_agent") @pg_query -# @increase_counter("patch_agent1") @beartype async def patch_agent( *, agent_id: UUID, developer_id: UUID, data: PatchAgentRequest -) -> tuple[str, dict]: +) -> tuple[str, list]: """ Constructs the SQL query to partially update an agent's details. @@ -52,27 +84,16 @@ async def patch_agent( data (PatchAgentRequest): A dictionary of fields to update. Returns: - tuple[str, dict]: A tuple containing the SQL query and its parameters. + tuple[str, list]: A tuple containing the SQL query and its parameters. """ - patch_fields = data.model_dump(exclude_unset=True) - set_clauses = [] - params = {} - - for key, value in patch_fields.items(): - if value is not None: # Only update non-null values - set_clauses.append(f"{key} = %({key})s") - params[key] = value - - set_clause = ", ".join(set_clauses) - - query = f""" - UPDATE agents - SET {set_clause} - WHERE agent_id = %(agent_id)s AND developer_id = %(developer_id)s - RETURNING *; - """ - - params["agent_id"] = agent_id - params["developer_id"] = developer_id - - return (query, params) + params = [ + developer_id, + agent_id, + data.name, + data.about, + data.metadata, + data.model, + data.default_settings.model_dump() if data.default_settings else None, + ] + + return query, params diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index 2116e49b0..d65354fa1 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -10,6 +10,10 @@ from fastapi import HTTPException from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest +from ...metrics.counters import increase_counter +from sqlglot import parse_one +from sqlglot.optimizer import optimize + from ..utils import ( partialclass, pg_query, @@ -20,6 +24,20 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") +raw_query = """ +UPDATE agents +SET + metadata = $3, + name = $4, + about = $5, + model = $6, + default_settings = $7::jsonb +WHERE agent_id = $2 AND developer_id = $1 +RETURNING *; +""" + +query = parse_one(raw_query).sql(pretty=True) + # @rewrap_exceptions( # { @@ -34,15 +52,12 @@ @wrap_in_class( ResourceUpdatedResponse, one=True, - transform=lambda d: {"id": d["agent_id"], "jobs": [], **d}, - _kind="inserted", + transform=lambda d: {"id": d["agent_id"], **d}, ) +# @increase_counter("update_agent") @pg_query -# @increase_counter("update_agent1") @beartype -async def update_agent( - *, agent_id: UUID, developer_id: UUID, data: UpdateAgentRequest -) -> tuple[str, dict]: +async def update_agent(*, agent_id: UUID, developer_id: UUID, data: UpdateAgentRequest) -> tuple[str, list]: """ Constructs the SQL query to fully update an agent's details. @@ -52,21 +67,19 @@ async def update_agent( data (UpdateAgentRequest): A dictionary containing all agent fields to update. Returns: - tuple[str, dict]: A tuple containing the SQL query and its parameters. + tuple[str, list]: A tuple containing the SQL query and its parameters. """ - fields = ", ".join( - [f"{key} = %({key})s" for key in data.model_dump(exclude_unset=True).keys()] - ) - params = {key: value for key, value in data.model_dump(exclude_unset=True).items()} - - query = f""" - UPDATE agents - SET {fields} - WHERE agent_id = %(agent_id)s AND developer_id = %(developer_id)s - RETURNING *; - """ - - params["agent_id"] = agent_id - params["developer_id"] = developer_id - + params = [ + developer_id, + agent_id, + data.metadata or {}, + data.name, + data.about, + data.model, + data.default_settings.model_dump() if data.default_settings else {}, + ] + print("*" * 100) + print(query) + print(params) + print("*" * 100) return (query, params) diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index db583e08f..152ab5ba9 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -1,5 +1,6 @@ import concurrent.futures import inspect +import re import socket import time from functools import partialmethod, wraps @@ -29,6 +30,19 @@ T = TypeVar("T") ModelT = TypeVar("ModelT", bound=BaseModel) +def generate_canonical_name(name: str) -> str: + """Convert a display name to a canonical name. + Example: "My Cool Agent!" -> "my_cool_agent" + """ + # Remove special characters, replace spaces with underscores + canonical = re.sub(r"[^\w\s-]", "", name.lower()) + canonical = re.sub(r"[-\s]+", "_", canonical) + + # Ensure it starts with a letter (prepend 'a' if not) + if not canonical[0].isalpha(): + canonical = f"a_{canonical}" + + return canonical def partialclass(cls, *args, **kwargs): cls_signature = inspect.signature(cls) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 1151b433d..46e45dbc7 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -95,19 +95,19 @@ def patch_embed_acompletion(): @fixture(scope="global") async def test_agent(dsn=pg_dsn, developer=test_developer): - pool = await asyncpg.create_pool(dsn=dsn) - - async with get_pg_client(pool=pool) as client: - agent = await create_agent( - developer_id=developer.id, - data=CreateAgentRequest( - model="gpt-4o-mini", - name="test agent", - about="test agent about", - metadata={"test": "test"}, - ), - client=client, - ) + pool = await create_db_pool(dsn=dsn) + + agent = await create_agent( + developer_id=developer.id, + data=CreateAgentRequest( + model="gpt-4o-mini", + name="test agent", + canonical_name=f"test_agent_{str(int(time.time()))}", + about="test agent about", + metadata={"test": "test"}, + ), + connection_pool=pool, + ) yield agent await pool.close() diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index f8f75fd0b..4b8ccd959 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,7 +1,9 @@ # Tests for agent queries from uuid import uuid4 +from uuid import UUID import asyncpg +from uuid_extensions import uuid7 from ward import raises, test from agents_api.autogen.openapi_model import ( @@ -9,10 +11,11 @@ CreateAgentRequest, CreateOrUpdateAgentRequest, PatchAgentRequest, + ResourceDeletedResponse, ResourceUpdatedResponse, UpdateAgentRequest, ) -from agents_api.clients.pg import get_pg_client +from agents_api.clients.pg import create_db_pool from agents_api.queries.agents import ( create_agent, create_or_update_agent, @@ -25,163 +28,141 @@ from tests.fixtures import pg_dsn, test_agent, test_developer_id -@test("model: create agent") +@test("query: create agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - await create_agent( - developer_id=developer_id, - data=CreateAgentRequest( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - ), - client=client, - ) - - -@test("model: create agent with instructions") + """Test that an agent can be successfully created.""" + + pool = await create_db_pool(dsn=dsn) + await create_agent( + developer_id=developer_id, + data=CreateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + ), + connection_pool=pool, + ) + + +@test("query: create agent with instructions sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - await create_agent( - developer_id=developer_id, - data=CreateAgentRequest( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - instructions=["test instruction"], - ), - client=client, - ) - + """Test that an agent can be successfully created or updated.""" + + pool = await create_db_pool(dsn=dsn) + await create_or_update_agent( + developer_id=developer_id, + agent_id=uuid4(), + data=CreateOrUpdateAgentRequest( + name="test agent", + canonical_name="test_agent2", + about="test agent about", + model="gpt-4o-mini", + instructions=["test instruction"], + ), + connection_pool=pool, + ) + + +@test("query: update agent sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): + """Test that an existing agent's information can be successfully updated.""" + + pool = await create_db_pool(dsn=dsn) + result = await update_agent( + agent_id=agent.id, + developer_id=developer_id, + data=UpdateAgentRequest( + name="updated agent", + about="updated agent about", + model="gpt-4o-mini", + default_settings={"temperature": 1.0}, + metadata={"hello": "world"}, + ), + connection_pool=pool, + ) -@test("model: create or update agent") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - await create_or_update_agent( - developer_id=developer_id, - agent_id=uuid4(), - data=CreateOrUpdateAgentRequest( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - instructions=["test instruction"], - ), - client=client, - ) + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) -@test("model: get agent not exists") +@test("query: get agent not exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test that retrieving a non-existent agent raises an exception.""" + agent_id = uuid4() - pool = await asyncpg.create_pool(dsn=dsn) + pool = await create_db_pool(dsn=dsn) with raises(Exception): - async with get_pg_client(pool=pool) as client: - await get_agent(agent_id=agent_id, developer_id=developer_id, client=client) + await get_agent(agent_id=agent_id, developer_id=developer_id, connection_pool=pool) -@test("model: get agent exists") +@test("query: get agent exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - result = await get_agent(agent_id=agent.id, developer_id=developer_id, client=client) + """Test that retrieving an existing agent returns the correct agent information.""" + + pool = await create_db_pool(dsn=dsn) + result = await get_agent( + agent_id=agent.id, + developer_id=developer_id, + connection_pool=pool, + ) assert result is not None assert isinstance(result, Agent) -@test("model: delete agent") +@test("query: list agents sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - temp_agent = await create_agent( - developer_id=developer_id, - data=CreateAgentRequest( - name="test agent", - about="test agent about", - model="gpt-4o-mini", - instructions=["test instruction"], - ), - client=client, - ) - - # Delete the agent - await delete_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) + """Test that listing agents returns a collection of agent information.""" + + pool = await create_db_pool(dsn=dsn) + result = await list_agents(developer_id=developer_id, connection_pool=pool) - # Check that the agent is deleted - with raises(Exception): - await get_agent(agent_id=temp_agent.id, developer_id=developer_id, client=client) + assert isinstance(result, list) + assert all(isinstance(agent, Agent) for agent in result) -@test("model: update agent") +@test("query: patch agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - result = await update_agent( - agent_id=agent.id, - developer_id=developer_id, - data=UpdateAgentRequest( - name="updated agent", - about="updated agent about", - model="gpt-4o-mini", - default_settings={"temperature": 1.0}, - metadata={"hello": "world"}, - ), - client=client, - ) + """Test that an agent can be successfully patched.""" + + pool = await create_db_pool(dsn=dsn) + result = await patch_agent( + agent_id=agent.id, + developer_id=developer_id, + data=PatchAgentRequest( + name="patched agent", + about="patched agent about", + default_settings={"temperature": 1.0}, + metadata={"something": "else"}, + ), + connection_pool=pool, + ) assert result is not None assert isinstance(result, ResourceUpdatedResponse) - async with get_pg_client(pool=pool) as client: - agent = await get_agent( - agent_id=agent.id, - developer_id=developer_id, - client=client, - ) - - assert "test" not in agent.metadata - -@test("model: patch agent") +@test("query: delete agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - result = await patch_agent( - agent_id=agent.id, - developer_id=developer_id, - data=PatchAgentRequest( - name="patched agent", - about="patched agent about", - default_settings={"temperature": 1.0}, - metadata={"something": "else"}, - ), - client=client, - ) + """Test that an agent can be successfully deleted.""" + + pool = await create_db_pool(dsn=dsn) + delete_result = await delete_agent(agent_id=agent.id, developer_id=developer_id, connection_pool=pool) - assert result is not None - assert isinstance(result, ResourceUpdatedResponse) + assert delete_result is not None + assert isinstance(delete_result, ResourceDeletedResponse) - async with get_pg_client(pool=pool) as client: - agent = await get_agent( - agent_id=agent.id, + # Verify the agent no longer exists + try: + await get_agent( developer_id=developer_id, - client=client, + agent_id=agent.id, + connection_pool=pool, ) - - assert "hello" in agent.metadata - - -@test("model: list agents") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Tests listing all agents associated with a developer in the database. Verifies that the correct list of agents is retrieved.""" - - pool = await asyncpg.create_pool(dsn=dsn) - async with get_pg_client(pool=pool) as client: - result = await list_agents(developer_id=developer_id, client=client) - - assert isinstance(result, list) - assert all(isinstance(agent, Agent) for agent in result) + except Exception: + pass + else: + assert ( + False + ), "Expected an exception to be raised when retrieving a deleted agent." From e745acce3ea2dcd7a7fd49685371689c36e27f5d Mon Sep 17 00:00:00 2001 From: Ahmad-mtos Date: Wed, 18 Dec 2024 08:51:36 +0000 Subject: [PATCH 18/29] refactor: Lint agents-api (CI) --- .../agents_api/queries/agents/create_agent.py | 3 ++- .../queries/agents/create_or_update_agent.py | 5 ++-- .../agents_api/queries/agents/delete_agent.py | 10 +++---- .../agents_api/queries/agents/get_agent.py | 24 ++++++++--------- .../agents_api/queries/agents/list_agents.py | 19 +++++-------- .../agents_api/queries/agents/patch_agent.py | 11 ++++---- .../agents_api/queries/agents/update_agent.py | 9 ++++--- agents-api/agents_api/queries/utils.py | 2 ++ agents-api/tests/fixtures.py | 4 +-- agents-api/tests/test_agent_queries.py | 27 ++++++++++--------- 10 files changed, 54 insertions(+), 60 deletions(-) diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 63ac4870f..454b24e3b 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -6,10 +6,10 @@ from typing import Any, TypeVar from uuid import UUID -from sqlglot import parse_one from beartype import beartype from fastapi import HTTPException from pydantic import ValidationError +from sqlglot import parse_one from uuid_extensions import uuid7 from ...autogen.openapi_model import Agent, CreateAgentRequest @@ -52,6 +52,7 @@ query = parse_one(raw_query).sql(pretty=True) + # @rewrap_exceptions( # { # psycopg_errors.ForeignKeyViolation: partialclass( diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index bbb897fe5..745be3fb8 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -6,11 +6,10 @@ from typing import Any, TypeVar from uuid import UUID -from sqlglot import parse_one -from sqlglot.optimizer import optimize - from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest from ..utils import ( diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index a5062f783..73da33261 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -8,6 +8,8 @@ from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceDeletedResponse from ..utils import ( @@ -16,11 +18,6 @@ rewrap_exceptions, wrap_in_class, ) -from beartype import beartype -from sqlglot import parse_one -from sqlglot.optimizer import optimize -from ...autogen.openapi_model import ResourceDeletedResponse -from ...common.utils.datetime import utcnow ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") @@ -51,6 +48,7 @@ # Convert the list of queries into a single query string query = parse_one(raw_query).sql(pretty=True) + # @rewrap_exceptions( # { # psycopg_errors.ForeignKeyViolation: partialclass( @@ -82,5 +80,5 @@ async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list """ # Note: We swap the parameter order because the queries use $1 for developer_id and $2 for agent_id params = [developer_id, agent_id] - + return (query, params) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index 061d0b165..d630a2aeb 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -8,19 +8,17 @@ from beartype import beartype from fastapi import HTTPException -from ...autogen.openapi_model import Agent -from ...metrics.counters import increase_counter from sqlglot import parse_one from sqlglot.optimizer import optimize + +from ...autogen.openapi_model import Agent +from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, rewrap_exceptions, wrap_in_class, ) -from beartype import beartype - -from ...autogen.openapi_model import Agent raw_query = """ SELECT @@ -48,14 +46,14 @@ # @rewrap_exceptions( - # { - # psycopg_errors.ForeignKeyViolation: partialclass( - # HTTPException, - # status_code=404, - # detail="The specified developer does not exist.", - # ) - # } - # # TODO: Add more exceptions +# { +# psycopg_errors.ForeignKeyViolation: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist.", +# ) +# } +# # TODO: Add more exceptions # ) @wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}) # @increase_counter("get_agent") diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 92165e414..b49e71886 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -8,6 +8,8 @@ from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one +from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent from ..utils import ( @@ -16,11 +18,6 @@ rewrap_exceptions, wrap_in_class, ) -from beartype import beartype -from sqlglot import parse_one -from sqlglot.optimizer import optimize - -from ...autogen.openapi_model import Agent ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") @@ -99,18 +96,14 @@ async def list_agents( final_query = query.replace("$7", "AND metadata @> $6::jsonb") else: final_query = query.replace("$7", "") - - params = [ - developer_id, - limit, - offset - ] - + + params = [developer_id, limit, offset] + params.append(sort_by) params.append(direction) if metadata_filter: params.append(metadata_filter) - + print(final_query) print(params) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 647ea3e52..929fd9c34 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -8,11 +8,10 @@ from beartype import beartype from fastapi import HTTPException - -from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse -from fastapi import HTTPException from sqlglot import parse_one from sqlglot.optimizer import optimize + +from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import ( partialclass, @@ -23,7 +22,7 @@ ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") - + raw_query = """ UPDATE agents SET @@ -93,7 +92,7 @@ async def patch_agent( data.about, data.metadata, data.model, - data.default_settings.model_dump() if data.default_settings else None, + data.default_settings.model_dump() if data.default_settings else None, ] - + return query, params diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index d65354fa1..3f413c78d 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -8,12 +8,11 @@ from beartype import beartype from fastapi import HTTPException - -from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest -from ...metrics.counters import increase_counter from sqlglot import parse_one from sqlglot.optimizer import optimize +from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest +from ...metrics.counters import increase_counter from ..utils import ( partialclass, pg_query, @@ -57,7 +56,9 @@ # @increase_counter("update_agent") @pg_query @beartype -async def update_agent(*, agent_id: UUID, developer_id: UUID, data: UpdateAgentRequest) -> tuple[str, list]: +async def update_agent( + *, agent_id: UUID, developer_id: UUID, data: UpdateAgentRequest +) -> tuple[str, list]: """ Constructs the SQL query to fully update an agent's details. diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 152ab5ba9..a3ce89d98 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -30,6 +30,7 @@ T = TypeVar("T") ModelT = TypeVar("ModelT", bound=BaseModel) + def generate_canonical_name(name: str) -> str: """Convert a display name to a canonical name. Example: "My Cool Agent!" -> "my_cool_agent" @@ -44,6 +45,7 @@ def generate_canonical_name(name: str) -> str: return canonical + def partialclass(cls, *args, **kwargs): cls_signature = inspect.signature(cls) bound = cls_signature.bind_partial(*args, **kwargs) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 46e45dbc7..25892d959 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -11,9 +11,9 @@ ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode +from agents_api.queries.agents.create_agent import create_agent from agents_api.queries.developers.create_developer import create_developer -from agents_api.queries.agents.create_agent import create_agent # from agents_api.queries.agents.delete_agent import delete_agent from agents_api.queries.developers.get_developer import get_developer @@ -100,7 +100,7 @@ async def test_agent(dsn=pg_dsn, developer=test_developer): agent = await create_agent( developer_id=developer.id, data=CreateAgentRequest( - model="gpt-4o-mini", + model="gpt-4o-mini", name="test agent", canonical_name=f"test_agent_{str(int(time.time()))}", about="test agent about", diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 4b8ccd959..b27f8abde 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,7 +1,6 @@ # Tests for agent queries -from uuid import uuid4 +from uuid import UUID, uuid4 -from uuid import UUID import asyncpg from uuid_extensions import uuid7 from ward import raises, test @@ -31,7 +30,7 @@ @test("query: create agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully created.""" - + pool = await create_db_pool(dsn=dsn) await create_agent( developer_id=developer_id, @@ -47,7 +46,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): @test("query: create agent with instructions sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully created or updated.""" - + pool = await create_db_pool(dsn=dsn) await create_or_update_agent( developer_id=developer_id, @@ -66,7 +65,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): @test("query: update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): """Test that an existing agent's information can be successfully updated.""" - + pool = await create_db_pool(dsn=dsn) result = await update_agent( agent_id=agent.id, @@ -88,18 +87,20 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): @test("query: get agent not exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that retrieving a non-existent agent raises an exception.""" - + agent_id = uuid4() pool = await create_db_pool(dsn=dsn) with raises(Exception): - await get_agent(agent_id=agent_id, developer_id=developer_id, connection_pool=pool) + await get_agent( + agent_id=agent_id, developer_id=developer_id, connection_pool=pool + ) @test("query: get agent exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): """Test that retrieving an existing agent returns the correct agent information.""" - + pool = await create_db_pool(dsn=dsn) result = await get_agent( agent_id=agent.id, @@ -114,7 +115,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): @test("query: list agents sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that listing agents returns a collection of agent information.""" - + pool = await create_db_pool(dsn=dsn) result = await list_agents(developer_id=developer_id, connection_pool=pool) @@ -125,7 +126,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): @test("query: patch agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): """Test that an agent can be successfully patched.""" - + pool = await create_db_pool(dsn=dsn) result = await patch_agent( agent_id=agent.id, @@ -146,9 +147,11 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): @test("query: delete agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): """Test that an agent can be successfully deleted.""" - + pool = await create_db_pool(dsn=dsn) - delete_result = await delete_agent(agent_id=agent.id, developer_id=developer_id, connection_pool=pool) + delete_result = await delete_agent( + agent_id=agent.id, developer_id=developer_id, connection_pool=pool + ) assert delete_result is not None assert isinstance(delete_result, ResourceDeletedResponse) From 2f392f745cf2f0420185f1179b7761b13866ff1f Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Wed, 18 Dec 2024 13:18:41 +0300 Subject: [PATCH 19/29] fix(agents-api): misc fixes --- .../agents_api/queries/agents/create_agent.py | 2 +- .../queries/agents/create_or_update_agent.py | 2 +- .../agents_api/queries/agents/delete_agent.py | 2 +- .../agents_api/queries/agents/get_agent.py | 2 +- .../agents_api/queries/agents/list_agents.py | 29 +++++++++---------- .../agents_api/queries/agents/patch_agent.py | 2 +- .../agents_api/queries/agents/update_agent.py | 7 ++--- agents-api/agents_api/queries/utils.py | 4 +++ agents-api/tests/test_agent_queries.py | 18 ++++-------- 9 files changed, 29 insertions(+), 39 deletions(-) diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 454b24e3b..81a408f30 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -87,7 +87,7 @@ one=True, transform=lambda d: {"id": d["agent_id"], **d}, ) -# @increase_counter("create_agent") +@increase_counter("create_agent") @pg_query @beartype async def create_agent( diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index 745be3fb8..d74cd57c2 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -66,7 +66,7 @@ one=True, transform=lambda d: {"id": d["agent_id"], **d}, ) -# @increase_counter("create_or_update_agent") +@increase_counter("create_or_update_agent") @pg_query @beartype async def create_or_update_agent( diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 73da33261..db4a3ab4f 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -64,7 +64,7 @@ one=True, transform=lambda d: {**d, "id": d["agent_id"], "deleted_at": utcnow()}, ) -# @increase_counter("delete_agent") +@increase_counter("delete_agent") @pg_query @beartype async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index d630a2aeb..a9893d747 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -56,7 +56,7 @@ # # TODO: Add more exceptions # ) @wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}) -# @increase_counter("get_agent") +@increase_counter("get_agent") @pg_query @beartype async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index b49e71886..48df01b90 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -36,7 +36,7 @@ created_at, updated_at FROM agents -WHERE developer_id = $1 $7 +WHERE developer_id = $1 {metadata_filter_query} ORDER BY CASE WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at END ASC NULLS LAST, CASE WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at END DESC NULLS LAST, @@ -45,8 +45,6 @@ LIMIT $2 OFFSET $3; """ -query = raw_query - # @rewrap_exceptions( # { @@ -59,7 +57,7 @@ # # TODO: Add more exceptions # ) @wrap_in_class(Agent, transform=lambda d: {"id": d["agent_id"], **d}) -# @increase_counter("list_agents") +@increase_counter("list_agents") @pg_query @beartype async def list_agents( @@ -91,20 +89,19 @@ async def list_agents( # Build metadata filter clause if needed - final_query = query - if metadata_filter: - final_query = query.replace("$7", "AND metadata @> $6::jsonb") - else: - final_query = query.replace("$7", "") - - params = [developer_id, limit, offset] + final_query = raw_query.format( + metadata_filter_query="AND metadata @> $6::jsonb" if metadata_filter else "" + ) + + params = [ + developer_id, + limit, + offset, + sort_by, + direction, + ] - params.append(sort_by) - params.append(direction) if metadata_filter: params.append(metadata_filter) - print(final_query) - print(params) - return final_query, params diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 929fd9c34..d2a172838 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -68,7 +68,7 @@ one=True, transform=lambda d: {"id": d["agent_id"], **d}, ) -# @increase_counter("patch_agent") +@increase_counter("patch_agent") @pg_query @beartype async def patch_agent( diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index 3f413c78d..d03994e9c 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -53,7 +53,7 @@ one=True, transform=lambda d: {"id": d["agent_id"], **d}, ) -# @increase_counter("update_agent") +@increase_counter("update_agent") @pg_query @beartype async def update_agent( @@ -79,8 +79,5 @@ async def update_agent( data.model, data.default_settings.model_dump() if data.default_settings else {}, ] - print("*" * 100) - print(query) - print(params) - print("*" * 100) + return (query, params) diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index a3ce89d98..ba9bade9e 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -1,5 +1,6 @@ import concurrent.futures import inspect +import random import re import socket import time @@ -43,6 +44,9 @@ def generate_canonical_name(name: str) -> str: if not canonical[0].isalpha(): canonical = f"a_{canonical}" + # Add 3 random numbers to the end + canonical = f"{canonical}_{random.randint(100, 999)}" + return canonical diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index b27f8abde..18d95b743 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,6 +1,5 @@ # Tests for agent queries -from uuid import UUID, uuid4 - +from uuid import UUID import asyncpg from uuid_extensions import uuid7 from ward import raises, test @@ -50,7 +49,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): pool = await create_db_pool(dsn=dsn) await create_or_update_agent( developer_id=developer_id, - agent_id=uuid4(), + agent_id=uuid7(), data=CreateOrUpdateAgentRequest( name="test agent", canonical_name="test_agent2", @@ -87,8 +86,8 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): @test("query: get agent not exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that retrieving a non-existent agent raises an exception.""" - - agent_id = uuid4() + + agent_id = uuid7() pool = await create_db_pool(dsn=dsn) with raises(Exception): @@ -156,16 +155,9 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): assert delete_result is not None assert isinstance(delete_result, ResourceDeletedResponse) - # Verify the agent no longer exists - try: + with raises(Exception): await get_agent( developer_id=developer_id, agent_id=agent.id, connection_pool=pool, ) - except Exception: - pass - else: - assert ( - False - ), "Expected an exception to be raised when retrieving a deleted agent." From 0579f3c03f62b1d02b597bd4918de5ab1eb4bd34 Mon Sep 17 00:00:00 2001 From: Ahmad-mtos Date: Wed, 18 Dec 2024 10:27:40 +0000 Subject: [PATCH 20/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/agents/list_agents.py | 2 +- agents-api/tests/test_agent_queries.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 48df01b90..69e91f206 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -92,7 +92,7 @@ async def list_agents( final_query = raw_query.format( metadata_filter_query="AND metadata @> $6::jsonb" if metadata_filter else "" ) - + params = [ developer_id, limit, diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 18d95b743..56a07ed03 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,5 +1,6 @@ # Tests for agent queries from uuid import UUID + import asyncpg from uuid_extensions import uuid7 from ward import raises, test @@ -86,7 +87,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): @test("query: get agent not exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that retrieving a non-existent agent raises an exception.""" - + agent_id = uuid7() pool = await create_db_pool(dsn=dsn) From 1b7a022d8d3aab446a683eed0914ffa021426b73 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Thu, 19 Dec 2024 01:14:40 +0300 Subject: [PATCH 21/29] wip --- agents-api/agents_api/autogen/Sessions.py | 40 +++ .../agents_api/queries/agents/create_agent.py | 7 +- .../queries/agents/create_or_update_agent.py | 6 +- .../agents_api/queries/agents/delete_agent.py | 7 +- .../agents_api/queries/agents/get_agent.py | 5 +- .../agents_api/queries/agents/list_agents.py | 6 +- .../agents_api/queries/agents/patch_agent.py | 5 +- .../agents_api/queries/agents/update_agent.py | 5 +- .../queries/developers/get_developer.py | 2 +- .../queries/entries/create_entries.py | 18 +- .../queries/entries/list_entries.py | 10 +- .../queries/sessions/create_session.py | 28 +- agents-api/agents_api/queries/utils.py | 17 +- agents-api/tests/fixtures.py | 44 ++- agents-api/tests/test_agent_queries.py | 2 - agents-api/tests/test_entry_queries.py | 10 +- agents-api/tests/test_messages_truncation.py | 2 +- agents-api/tests/test_session_queries.py | 339 +++++++++++------- .../integrations/autogen/Sessions.py | 40 +++ typespec/sessions/models.tsp | 6 + .../@typespec/openapi3/openapi-1.0.0.yaml | 53 +++ 21 files changed, 439 insertions(+), 213 deletions(-) diff --git a/agents-api/agents_api/autogen/Sessions.py b/agents-api/agents_api/autogen/Sessions.py index 460fd25ce..e2a9ce164 100644 --- a/agents-api/agents_api/autogen/Sessions.py +++ b/agents-api/agents_api/autogen/Sessions.py @@ -31,6 +31,10 @@ class CreateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -67,6 +75,10 @@ class PatchSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptionsUpdate | None = None metadata: dict[str, Any] | None = None @@ -121,6 +137,10 @@ class Session(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None """ Summary (null at the beginning) - generated automatically after every interaction @@ -145,6 +165,10 @@ class Session(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -197,6 +221,10 @@ class UpdateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -240,6 +272,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 81a408f30..bb111b0df 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -7,18 +7,17 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException -from pydantic import ValidationError from sqlglot import parse_one from uuid_extensions import uuid7 +from ...metrics.counters import increase_counter + from ...autogen.openapi_model import Agent, CreateAgentRequest from ..utils import ( generate_canonical_name, - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index d74cd57c2..6cfb83767 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -7,17 +7,15 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest +from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index db4a3ab4f..9c3ee5585 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -7,16 +7,15 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceDeletedResponse +from ...metrics.counters import increase_counter +from ...common.utils.datetime import utcnow from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index a9893d747..dce424771 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -7,17 +7,14 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) raw_query = """ diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 69e91f206..3698c68f1 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -8,15 +8,13 @@ from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent +from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index d2a172838..6f9cb3b9c 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -7,17 +7,14 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index d03994e9c..cd15313a2 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -7,17 +7,14 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 373a2fb36..28be9a4b1 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -12,8 +12,8 @@ from ..utils import ( partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) # TODO: Add verify_developer diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 24c0be26e..a54104274 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -14,14 +14,10 @@ # Query for checking if the session exists session_exists_query = """ -SELECT CASE - WHEN EXISTS ( - SELECT 1 FROM sessions - WHERE session_id = $1 AND developer_id = $2 - ) - THEN TRUE - ELSE (SELECT NULL::boolean WHERE FALSE) -- This raises a NO_DATA_FOUND error -END; +SELECT EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 +) AS exists; """ # Define the raw SQL query for creating entries @@ -71,6 +67,10 @@ status_code=400, detail=str(exc), ), + asyncpg.NoDataFoundError: lambda exc: HTTPException( + status_code=404, + detail="Session not found", + ), } ) @wrap_in_class( @@ -166,7 +166,7 @@ async def add_entry_relations( item.get("is_leaf", False), # $5 ] ) - + return [ ( session_exists_query, diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 0aeb92a25..3f4a0699e 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -62,6 +62,10 @@ status_code=400, detail=str(exc), ), + asyncpg.NoDataFoundError: lambda exc: HTTPException( + status_code=404, + detail="Session not found", + ), } ) @wrap_in_class(Entry) @@ -78,7 +82,7 @@ async def list_entries( sort_by: Literal["created_at", "timestamp"] = "timestamp", direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], -) -> list[tuple[str, list]]: +) -> list[tuple[str, list] | tuple[str, list, str]]: if limit < 1 or limit > 1000: raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") if offset < 0: @@ -98,14 +102,14 @@ async def list_entries( developer_id, # $5 exclude_relations, # $6 ] - return [ ( session_exists_query, [session_id, developer_id], + "fetchrow", ), ( query, - entry_params, + entry_params ), ] diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 3074f087b..baa3f09d1 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -45,11 +45,7 @@ participant_type, participant_id ) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; +VALUES ($1, $2, $3, $4); """).sql(pretty=True) @@ -67,7 +63,7 @@ ), } ) -@wrap_in_class(Session, one=True, transform=lambda d: {**d, "id": d["session_id"]}) +@wrap_in_class(Session, transform=lambda d: {**d, "id": d["session_id"]}) @increase_counter("create_session") @pg_query @beartype @@ -76,7 +72,7 @@ async def create_session( developer_id: UUID, session_id: UUID, data: CreateSessionRequest, -) -> list[tuple[str, list]]: +) -> list[tuple[str, list] | tuple[str, list, str]]: """ Constructs SQL queries to create a new session and its participant lookups. @@ -86,7 +82,7 @@ async def create_session( data (CreateSessionRequest): Session creation data Returns: - list[tuple[str, list]]: SQL queries and their parameters + list[tuple[str, list] | tuple[str, list, str]]: SQL queries and their parameters """ # Handle participants users = data.users or ([data.user] if data.user else []) @@ -122,15 +118,15 @@ async def create_session( data.recall_options or {}, # $10 ] - # Prepare lookup parameters - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] + # Prepare lookup parameters as a list of parameter lists + lookup_params = [] + for ptype, pid in zip(participant_types, participant_ids): + lookup_params.append([developer_id, session_id, ptype, pid]) + print("*" * 100) + print(lookup_params) + print("*" * 100) return [ (session_query, session_params), - (lookup_query, lookup_params), + (lookup_query, lookup_params, "fetchmany"), ] diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index ba9bade9e..194cba7bc 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -69,7 +69,7 @@ class AsyncPGFetchArgs(TypedDict): type SQLQuery = str -type FetchMethod = Literal["fetch", "fetchmany"] +type FetchMethod = Literal["fetch", "fetchmany", "fetchrow"] type PGQueryArgs = tuple[SQLQuery, list[Any]] | tuple[SQLQuery, list[Any], FetchMethod] type PreparedPGQueryArgs = tuple[FetchMethod, AsyncPGFetchArgs] type BatchedPreparedPGQueryArgs = list[PreparedPGQueryArgs] @@ -102,6 +102,13 @@ def prepare_pg_query_args( ), ) ) + case (query, variables, "fetchrow"): + batch.append( + ( + "fetchrow", + AsyncPGFetchArgs(query=query, args=variables, timeout=query_timeout), + ) + ) case _: raise ValueError("Invalid query arguments") @@ -161,6 +168,14 @@ async def wrapper( query, *args, timeout=timeout ) + print("%" * 100) + print(results) + print(*args) + print("%" * 100) + + if method_name == "fetchrow" and (len(results) == 0 or results.get("bool") is None): + raise asyncpg.NoDataFoundError + end = timeit and time.perf_counter() timeit and print( diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 25892d959..9153785a4 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,5 +1,6 @@ import random import string +import time from uuid import UUID from fastapi.testclient import TestClient @@ -7,6 +8,8 @@ from ward import fixture from agents_api.autogen.openapi_model import ( + CreateAgentRequest, + CreateSessionRequest, CreateUserRequest, ) from agents_api.clients.pg import create_db_pool @@ -24,8 +27,8 @@ # from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup # from agents_api.queries.files.create_file import create_file # from agents_api.queries.files.delete_file import delete_file -# from agents_api.queries.session.create_session import create_session -# from agents_api.queries.session.delete_session import delete_session +from agents_api.queries.sessions.create_session import create_session + # from agents_api.queries.task.create_task import create_task # from agents_api.queries.task.delete_task import delete_task # from agents_api.queries.tools.create_tools import create_tools @@ -150,22 +153,27 @@ async def test_new_developer(dsn=pg_dsn, email=random_email): return developer -# @fixture(scope="global") -# async def test_session( -# dsn=pg_dsn, -# developer_id=test_developer_id, -# test_user=test_user, -# test_agent=test_agent, -# ): -# async with get_pg_client(dsn=dsn) as client: -# session = await create_session( -# developer_id=developer_id, -# data=CreateSessionRequest( -# agent=test_agent.id, user=test_user.id, metadata={"test": "test"} -# ), -# client=client, -# ) -# yield session +@fixture(scope="global") +async def test_session( + dsn=pg_dsn, + developer_id=test_developer_id, + test_user=test_user, + test_agent=test_agent, +): + pool = await create_db_pool(dsn=dsn) + + session = await create_session( + developer_id=developer_id, + data=CreateSessionRequest( + agent=test_agent.id, + user=test_user.id, + metadata={"test": "test"}, + system_template="test system template", + ), + connection_pool=pool, + ) + + return session # @fixture(scope="global") diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 56a07ed03..b6cb7aedc 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,7 +1,5 @@ # Tests for agent queries -from uuid import UUID -import asyncpg from uuid_extensions import uuid7 from ward import raises, test diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 87d9cdb4f..da53ce06d 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,7 +3,7 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ -from uuid import uuid4 +from uuid_extensions import uuid7 from fastapi import HTTPException from ward import raises, test @@ -11,7 +11,7 @@ from agents_api.autogen.openapi_model import CreateEntryRequest from agents_api.clients.pg import create_db_pool from agents_api.queries.entries import create_entries, list_entries -from tests.fixtures import pg_dsn, test_developer # , test_session +from tests.fixtures import pg_dsn, test_developer, test_session # , test_session MODEL = "gpt-4o-mini" @@ -31,11 +31,10 @@ async def _(dsn=pg_dsn, developer=test_developer): with raises(HTTPException) as exc_info: await create_entries( developer_id=developer.id, - session_id=uuid4(), + session_id=uuid7(), data=[test_entry], connection_pool=pool, ) - assert exc_info.raised.status_code == 404 @@ -48,10 +47,9 @@ async def _(dsn=pg_dsn, developer=test_developer): with raises(HTTPException) as exc_info: await list_entries( developer_id=developer.id, - session_id=uuid4(), + session_id=uuid7(), connection_pool=pool, ) - assert exc_info.raised.status_code == 404 diff --git a/agents-api/tests/test_messages_truncation.py b/agents-api/tests/test_messages_truncation.py index 39cc02c2c..bb1eaee30 100644 --- a/agents-api/tests/test_messages_truncation.py +++ b/agents-api/tests/test_messages_truncation.py @@ -1,4 +1,4 @@ -# from uuid import uuid4 + # from uuid_extensions import uuid7 # from ward import raises, test diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 4fdc7e6e4..b85268434 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -8,79 +8,116 @@ from agents_api.autogen.openapi_model import ( Session, + CreateSessionRequest, + CreateOrUpdateSessionRequest, + UpdateSessionRequest, + PatchSessionRequest, + ResourceUpdatedResponse, + ResourceDeletedResponse, ) from agents_api.clients.pg import create_db_pool from agents_api.queries.sessions import ( count_sessions, get_session, list_sessions, + create_session, + create_or_update_session, + update_session, + patch_session, + delete_session, ) from tests.fixtures import ( pg_dsn, test_developer_id, -) # , test_session, test_agent, test_user - -# @test("query: create session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): -# """Test that a session can be successfully created.""" - -# pool = await create_db_pool(dsn=dsn) -# await create_session( -# developer_id=developer_id, -# session_id=uuid7(), -# data=CreateSessionRequest( -# users=[user.id], -# agents=[agent.id], -# situation="test session", -# ), -# connection_pool=pool, -# ) - - -# @test("query: create or update session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): -# """Test that a session can be successfully created or updated.""" - -# pool = await create_db_pool(dsn=dsn) -# await create_or_update_session( -# developer_id=developer_id, -# session_id=uuid7(), -# data=CreateOrUpdateSessionRequest( -# users=[user.id], -# agents=[agent.id], -# situation="test session", -# ), -# connection_pool=pool, -# ) - - -# @test("query: update session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent): -# """Test that an existing session's information can be successfully updated.""" - -# pool = await create_db_pool(dsn=dsn) -# update_result = await update_session( -# session_id=session.id, -# developer_id=developer_id, -# data=UpdateSessionRequest( -# agents=[agent.id], -# situation="updated session", -# ), -# connection_pool=pool, -# ) - -# assert update_result is not None -# assert isinstance(update_result, ResourceUpdatedResponse) -# assert update_result.updated_at > session.created_at - - -@test("query: get session not exists sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Test that retrieving a non-existent session returns an empty result.""" + test_developer, + test_user, + test_agent, + test_session, +) + +@test("query: create session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created.""" + + pool = await create_db_pool(dsn=dsn) session_id = uuid7() + data = CreateSessionRequest( + users=[user.id], + agents=[agent.id], + situation="test session", + system_template="test system template", + ) + result = await create_session( + developer_id=developer_id, + session_id=session_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session), f"Result is not a Session, {result}" + assert result.id == session_id + assert result.developer_id == developer_id + assert result.situation == "test session" + assert set(result.users) == {user.id} + assert set(result.agents) == {agent.id} + + +@test("query: create or update session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created or updated.""" + pool = await create_db_pool(dsn=dsn) + session_id = uuid7() + data = CreateOrUpdateSessionRequest( + users=[user.id], + agents=[agent.id], + situation="test session", + ) + result = await create_or_update_session( + developer_id=developer_id, + session_id=session_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session) + assert result.id == session_id + assert result.developer_id == developer_id + assert result.situation == "test session" + assert set(result.users) == {user.id} + assert set(result.agents) == {agent.id} + + +@test("query: get session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test retrieving an existing session.""" + pool = await create_db_pool(dsn=dsn) + result = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session) + assert result.id == session.id + assert result.developer_id == developer_id + + +@test("query: get session does not exist") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test retrieving a non-existent session.""" + + session_id = uuid7() + pool = await create_db_pool(dsn=dsn) with raises(Exception): await get_session( session_id=session_id, @@ -89,90 +126,136 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) -# @test("query: get session exists sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test that retrieving an existing session returns the correct session information.""" +@test("query: list sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with default pagination.""" -# pool = await create_db_pool(dsn=dsn) -# result = await get_session( -# session_id=session.id, -# developer_id=developer_id, -# connection_pool=pool, -# ) + pool = await create_db_pool(dsn=dsn) + result, _ = await list_sessions( + developer_id=developer_id, + limit=10, + offset=0, + connection_pool=pool, + ) -# assert result is not None -# assert isinstance(result, Session) + assert isinstance(result, list) + assert len(result) >= 1 + assert any(s.id == session.id for s in result) -@test("query: list sessions when none exist sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Test that listing sessions returns a collection of session information.""" +@test("query: list sessions with filters") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with specific filters.""" pool = await create_db_pool(dsn=dsn) - result = await list_sessions( + result, _ = await list_sessions( developer_id=developer_id, + limit=10, + offset=0, + filters={"situation": "test session"}, connection_pool=pool, ) assert isinstance(result, list) assert len(result) >= 1 - assert all(isinstance(session, Session) for session in result) - - -# @test("query: patch session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent): -# """Test that a session can be successfully patched.""" - -# pool = await create_db_pool(dsn=dsn) -# patch_result = await patch_session( -# developer_id=developer_id, -# session_id=session.id, -# data=PatchSessionRequest( -# agents=[agent.id], -# situation="patched session", -# metadata={"test": "metadata"}, -# ), -# connection_pool=pool, -# ) - -# assert patch_result is not None -# assert isinstance(patch_result, ResourceUpdatedResponse) -# assert patch_result.updated_at > session.created_at - - -# @test("query: delete session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test that a session can be successfully deleted.""" - -# pool = await create_db_pool(dsn=dsn) -# delete_result = await delete_session( -# developer_id=developer_id, -# session_id=session.id, -# connection_pool=pool, -# ) - -# assert delete_result is not None -# assert isinstance(delete_result, ResourceDeletedResponse) - -# # Verify the session no longer exists -# with raises(Exception): -# await get_session( -# developer_id=developer_id, -# session_id=session.id, -# connection_pool=pool, -# ) - - -@test("query: count sessions sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Test that sessions can be counted.""" + assert all(s.situation == "test session" for s in result) + + +@test("query: count sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test counting the number of sessions for a developer.""" pool = await create_db_pool(dsn=dsn) - result = await count_sessions( + count = await count_sessions( developer_id=developer_id, connection_pool=pool, ) - assert isinstance(result, dict) - assert "count" in result - assert isinstance(result["count"], int) + assert isinstance(count, int) + assert count >= 1 + + +@test("query: update session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +): + """Test that an existing session's information can be successfully updated.""" + + pool = await create_db_pool(dsn=dsn) + data = UpdateSessionRequest( + agents=[agent.id], + situation="updated session", + ) + result = await update_session( + session_id=session.id, + developer_id=developer_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + updated_session = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + assert updated_session.situation == "updated session" + assert set(updated_session.agents) == {agent.id} + + +@test("query: patch session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +): + """Test that a session can be successfully patched.""" + + pool = await create_db_pool(dsn=dsn) + data = PatchSessionRequest( + agents=[agent.id], + situation="patched session", + metadata={"test": "metadata"}, + ) + result = await patch_session( + developer_id=developer_id, + session_id=session.id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + patched_session = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + assert patched_session.situation == "patched session" + assert set(patched_session.agents) == {agent.id} + assert patched_session.metadata == {"test": "metadata"} + + +@test("query: delete session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test that a session can be successfully deleted.""" + + pool = await create_db_pool(dsn=dsn) + delete_result = await delete_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + assert delete_result is not None + assert isinstance(delete_result, ResourceDeletedResponse) + + with raises(Exception): + await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) diff --git a/integrations-service/integrations/autogen/Sessions.py b/integrations-service/integrations/autogen/Sessions.py index 460fd25ce..e2a9ce164 100644 --- a/integrations-service/integrations/autogen/Sessions.py +++ b/integrations-service/integrations/autogen/Sessions.py @@ -31,6 +31,10 @@ class CreateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -67,6 +75,10 @@ class PatchSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptionsUpdate | None = None metadata: dict[str, Any] | None = None @@ -121,6 +137,10 @@ class Session(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None """ Summary (null at the beginning) - generated automatically after every interaction @@ -145,6 +165,10 @@ class Session(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -197,6 +221,10 @@ class UpdateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -240,6 +272,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/typespec/sessions/models.tsp b/typespec/sessions/models.tsp index f15453a5f..720625f3b 100644 --- a/typespec/sessions/models.tsp +++ b/typespec/sessions/models.tsp @@ -63,6 +63,9 @@ model Session { /** A specific situation that sets the background for this session */ situation: string = defaultSessionSystemMessage; + /** System prompt for this session */ + system_template: string | null = null; + /** Summary (null at the beginning) - generated automatically after every interaction */ @visibility("read") summary: string | null = null; @@ -83,6 +86,9 @@ model Session { * If a tool call is not made, the model's output will be returned as is. */ auto_run_tools: boolean = false; + /** Whether to forward tool calls to the model */ + forward_tool_calls: boolean = false; + recall_options?: RecallOptions | null = null; ...HasId; diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index 9298ab458..d4835a695 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -3761,10 +3761,12 @@ components: required: - id - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: id: $ref: '#/components/schemas/Common.uuid' @@ -3840,6 +3842,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -3865,6 +3872,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -3880,10 +3891,12 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: user: allOf: @@ -3957,6 +3970,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -3982,6 +4000,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4096,6 +4118,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -4121,6 +4148,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4189,11 +4220,13 @@ components: type: object required: - situation + - system_template - summary - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls - id - created_at - updated_at @@ -4254,6 +4287,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null summary: type: string nullable: true @@ -4285,6 +4323,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4360,10 +4402,12 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: situation: type: string @@ -4421,6 +4465,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -4446,6 +4495,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: From db318013484ef0eeab5171b9456c8c221e545867 Mon Sep 17 00:00:00 2001 From: Ahmad-mtos Date: Wed, 18 Dec 2024 22:15:29 +0000 Subject: [PATCH 22/29] refactor: Lint agents-api (CI) --- .../agents_api/queries/agents/create_agent.py | 5 ++--- .../queries/agents/create_or_update_agent.py | 2 +- .../agents_api/queries/agents/delete_agent.py | 4 ++-- .../agents_api/queries/agents/get_agent.py | 2 +- .../agents_api/queries/agents/list_agents.py | 2 +- .../agents_api/queries/agents/patch_agent.py | 2 +- .../agents_api/queries/agents/update_agent.py | 2 +- .../queries/developers/get_developer.py | 2 +- .../queries/entries/create_entries.py | 2 +- .../queries/entries/list_entries.py | 5 +---- agents-api/agents_api/queries/utils.py | 8 +++++-- agents-api/tests/test_entry_queries.py | 3 +-- agents-api/tests/test_messages_truncation.py | 1 - agents-api/tests/test_session_queries.py | 22 +++++++++---------- 14 files changed, 30 insertions(+), 32 deletions(-) diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index bb111b0df..a6b56d84f 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -10,14 +10,13 @@ from sqlglot import parse_one from uuid_extensions import uuid7 -from ...metrics.counters import increase_counter - from ...autogen.openapi_model import Agent, CreateAgentRequest +from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index 6cfb83767..2aa0d1501 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -14,8 +14,8 @@ from ..utils import ( generate_canonical_name, pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 9c3ee5585..df0f0c325 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -10,12 +10,12 @@ from sqlglot import parse_one from ...autogen.openapi_model import ResourceDeletedResponse -from ...metrics.counters import increase_counter from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter from ..utils import ( pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index dce424771..2cf1ef28d 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -13,8 +13,8 @@ from ...metrics.counters import increase_counter from ..utils import ( pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) raw_query = """ diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 3698c68f1..306b7465b 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -13,8 +13,8 @@ from ...metrics.counters import increase_counter from ..utils import ( pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 6f9cb3b9c..8d17c9f49 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -13,8 +13,8 @@ from ...metrics.counters import increase_counter from ..utils import ( pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index cd15313a2..fe5e31ac6 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -13,8 +13,8 @@ from ...metrics.counters import increase_counter from ..utils import ( pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 28be9a4b1..373a2fb36 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -12,8 +12,8 @@ from ..utils import ( partialclass, pg_query, - wrap_in_class, rewrap_exceptions, + wrap_in_class, ) # TODO: Add verify_developer diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index a54104274..4c1f7bfa7 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -166,7 +166,7 @@ async def add_entry_relations( item.get("is_leaf", False), # $5 ] ) - + return [ ( session_exists_query, diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 3f4a0699e..1c398f0ab 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -108,8 +108,5 @@ async def list_entries( [session_id, developer_id], "fetchrow", ), - ( - query, - entry_params - ), + (query, entry_params), ] diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 194cba7bc..73113580d 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -106,7 +106,9 @@ def prepare_pg_query_args( batch.append( ( "fetchrow", - AsyncPGFetchArgs(query=query, args=variables, timeout=query_timeout), + AsyncPGFetchArgs( + query=query, args=variables, timeout=query_timeout + ), ) ) case _: @@ -173,7 +175,9 @@ async def wrapper( print(*args) print("%" * 100) - if method_name == "fetchrow" and (len(results) == 0 or results.get("bool") is None): + if method_name == "fetchrow" and ( + len(results) == 0 or results.get("bool") is None + ): raise asyncpg.NoDataFoundError end = timeit and time.perf_counter() diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index da53ce06d..60a387591 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,9 +3,8 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ -from uuid_extensions import uuid7 - from fastapi import HTTPException +from uuid_extensions import uuid7 from ward import raises, test from agents_api.autogen.openapi_model import CreateEntryRequest diff --git a/agents-api/tests/test_messages_truncation.py b/agents-api/tests/test_messages_truncation.py index bb1eaee30..1a6c344e6 100644 --- a/agents-api/tests/test_messages_truncation.py +++ b/agents-api/tests/test_messages_truncation.py @@ -1,4 +1,3 @@ - # from uuid_extensions import uuid7 # from ward import raises, test diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index b85268434..8e512379f 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -7,32 +7,32 @@ from ward import raises, test from agents_api.autogen.openapi_model import ( - Session, - CreateSessionRequest, CreateOrUpdateSessionRequest, - UpdateSessionRequest, + CreateSessionRequest, PatchSessionRequest, - ResourceUpdatedResponse, ResourceDeletedResponse, + ResourceUpdatedResponse, + Session, + UpdateSessionRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.queries.sessions import ( count_sessions, + create_or_update_session, + create_session, + delete_session, get_session, list_sessions, - create_session, - create_or_update_session, - update_session, patch_session, - delete_session, + update_session, ) from tests.fixtures import ( pg_dsn, - test_developer_id, - test_developer, - test_user, test_agent, + test_developer, + test_developer_id, test_session, + test_user, ) From 638fefb6b2a5c79729db03be298f7c47c243de25 Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Wed, 18 Dec 2024 18:18:39 -0500 Subject: [PATCH 23/29] chore: minor refactors --- .../agents_api/queries/agents/__init__.py | 10 +++ .../agents_api/queries/agents/create_agent.py | 15 ++-- .../queries/agents/create_or_update_agent.py | 15 ++-- .../agents_api/queries/agents/delete_agent.py | 20 +++--- .../agents_api/queries/agents/get_agent.py | 17 ++--- .../agents_api/queries/agents/list_agents.py | 13 ++-- .../agents_api/queries/agents/patch_agent.py | 14 ++-- .../agents_api/queries/agents/update_agent.py | 15 ++-- .../queries/entries/create_entries.py | 72 ++++++++++--------- .../queries/entries/delete_entries.py | 54 +++++++------- .../agents_api/queries/entries/get_history.py | 28 ++++---- .../queries/entries/list_entries.py | 51 +++++++------ 12 files changed, 171 insertions(+), 153 deletions(-) diff --git a/agents-api/agents_api/queries/agents/__init__.py b/agents-api/agents_api/queries/agents/__init__.py index ebd169040..c0712c47c 100644 --- a/agents-api/agents_api/queries/agents/__init__.py +++ b/agents-api/agents_api/queries/agents/__init__.py @@ -19,3 +19,13 @@ from .list_agents import list_agents from .patch_agent import patch_agent from .update_agent import update_agent + +__all__ = [ + "create_agent", + "create_or_update_agent", + "delete_agent", + "get_agent", + "list_agents", + "patch_agent", + "update_agent", +] diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index a6b56d84f..2d8df7978 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -19,10 +19,8 @@ wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" INSERT INTO agents ( developer_id, agent_id, @@ -46,9 +44,7 @@ $9 ) RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -135,4 +131,7 @@ async def create_agent( default_settings, ] - return query, params + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index 2aa0d1501..e96b30c77 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -18,10 +18,8 @@ wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" INSERT INTO agents ( developer_id, agent_id, @@ -45,9 +43,7 @@ $9 ) RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -110,4 +106,7 @@ async def create_or_update_agent( default_settings, ] - return (query, params) + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index df0f0c325..6738374db 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -11,17 +11,14 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter from ..utils import ( pg_query, rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" WITH deleted_docs AS ( DELETE FROM docs WHERE developer_id = $1 @@ -41,13 +38,10 @@ DELETE FROM agents WHERE agent_id = $2 AND developer_id = $1 RETURNING developer_id, agent_id; -""" - - -# Convert the list of queries into a single query string -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) +# @rewrap_exceptions( # @rewrap_exceptions( # { # psycopg_errors.ForeignKeyViolation: partialclass( @@ -63,7 +57,6 @@ one=True, transform=lambda d: {**d, "id": d["agent_id"], "deleted_at": utcnow()}, ) -@increase_counter("delete_agent") @pg_query @beartype async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: @@ -80,4 +73,7 @@ async def delete_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list # Note: We swap the parameter order because the queries use $1 for developer_id and $2 for agent_id params = [developer_id, agent_id] - return (query, params) + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index 2cf1ef28d..916572db1 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -10,14 +10,14 @@ from sqlglot import parse_one from ...autogen.openapi_model import Agent -from ...metrics.counters import increase_counter from ..utils import ( pg_query, rewrap_exceptions, wrap_in_class, ) -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" SELECT agent_id, developer_id, @@ -34,12 +34,7 @@ agents WHERE agent_id = $2 AND developer_id = $1; -""" - -query = parse_one(raw_query).sql(pretty=True) - -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") +""").sql(pretty=True) # @rewrap_exceptions( @@ -53,7 +48,6 @@ # # TODO: Add more exceptions # ) @wrap_in_class(Agent, one=True, transform=lambda d: {"id": d["agent_id"], **d}) -@increase_counter("get_agent") @pg_query @beartype async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: @@ -68,4 +62,7 @@ async def get_agent(*, agent_id: UUID, developer_id: UUID) -> tuple[str, list]: tuple[list[str], dict]: A tuple containing the SQL query and its parameters. """ - return (query, [developer_id, agent_id]) + return ( + agent_query, + [developer_id, agent_id], + ) diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 306b7465b..ce12b32b3 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -10,16 +10,13 @@ from fastapi import HTTPException from ...autogen.openapi_model import Agent -from ...metrics.counters import increase_counter from ..utils import ( pg_query, rewrap_exceptions, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - +# Define the raw SQL query raw_query = """ SELECT agent_id, @@ -55,7 +52,6 @@ # # TODO: Add more exceptions # ) @wrap_in_class(Agent, transform=lambda d: {"id": d["agent_id"], **d}) -@increase_counter("list_agents") @pg_query @beartype async def list_agents( @@ -87,7 +83,7 @@ async def list_agents( # Build metadata filter clause if needed - final_query = raw_query.format( + agent_query = raw_query.format( metadata_filter_query="AND metadata @> $6::jsonb" if metadata_filter else "" ) @@ -102,4 +98,7 @@ async def list_agents( if metadata_filter: params.append(metadata_filter) - return final_query, params + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 8d17c9f49..7fb63feda 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -17,10 +17,9 @@ wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" UPDATE agents SET name = CASE @@ -45,9 +44,7 @@ END WHERE agent_id = $2 AND developer_id = $1 RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -92,4 +89,7 @@ async def patch_agent( data.default_settings.model_dump() if data.default_settings else None, ] - return query, params + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index fe5e31ac6..79b520cb8 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -17,10 +17,8 @@ wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -raw_query = """ +# Define the raw SQL query +agent_query = parse_one(""" UPDATE agents SET metadata = $3, @@ -30,9 +28,7 @@ default_settings = $7::jsonb WHERE agent_id = $2 AND developer_id = $1 RETURNING *; -""" - -query = parse_one(raw_query).sql(pretty=True) +""").sql(pretty=True) # @rewrap_exceptions( @@ -77,4 +73,7 @@ async def update_agent( data.default_settings.model_dump() if data.default_settings else {}, ] - return (query, params) + return ( + agent_query, + params, + ) diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 4c1f7bfa7..7f6e2d4d7 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -10,7 +10,7 @@ from ...common.utils.datetime import utcnow from ...common.utils.messages import content_to_json from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass # Query for checking if the session exists session_exists_query = """ @@ -53,26 +53,30 @@ """ -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=409, - detail=str(exc), - ), - asyncpg.NotNullViolationError: lambda exc: HTTPException( - status_code=400, - detail=str(exc), - ), - asyncpg.NoDataFoundError: lambda exc: HTTPException( - status_code=404, - detail="Session not found", - ), - } -) +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="Entry already exists", +# ), +# asyncpg.NotNullViolationError: partialclass( +# HTTPException, +# status_code=400, +# detail="Not null violation", +# ), +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# } +# ) @wrap_in_class( Entry, transform=lambda d: { @@ -128,18 +132,20 @@ async def create_entries( ] -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=409, - detail=str(exc), - ), - } -) +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="Entry already exists", +# ), +# } +# ) @wrap_in_class(Relation) @increase_counter("add_entry_relations") @pg_query diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py index dfdadb8da..ce1590fd4 100644 --- a/agents-api/agents_api/queries/entries/delete_entries.py +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -9,7 +9,7 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass # Define the raw SQL query for deleting entries with a developer check delete_entry_query = parse_one(""" @@ -57,18 +57,20 @@ """ -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail="The specified session or developer does not exist.", - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=409, - detail="The specified session has already been deleted.", - ), - } -) +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified session or developer does not exist.", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="The specified session has already been deleted.", +# ), +# } +# ) @wrap_in_class( ResourceDeletedResponse, one=True, @@ -94,18 +96,20 @@ async def delete_entries_for_session( ] -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail="The specified entries, session, or developer does not exist.", - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=409, - detail="One or more specified entries have already been deleted.", - ), - } -) +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified entries, session, or developer does not exist.", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="One or more specified entries have already been deleted.", +# ), +# } +# ) @wrap_in_class( ResourceDeletedResponse, transform=lambda d: { diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index 8f0ddf4a1..2c28b4f21 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -6,7 +6,7 @@ from sqlglot import parse_one from ...autogen.openapi_model import History -from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass # Define the raw SQL query for getting history with a developer check history_query = parse_one(""" @@ -30,18 +30,20 @@ """).sql(pretty=True) -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - } -) +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# } +# ) @wrap_in_class( History, one=True, diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 1c398f0ab..657f5563b 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -7,7 +7,7 @@ from ...autogen.openapi_model import Entry from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass # Query for checking if the session exists session_exists_query = """ @@ -48,26 +48,30 @@ """ -@rewrap_exceptions( - { - asyncpg.ForeignKeyViolationError: lambda exc: HTTPException( - status_code=404, - detail=str(exc), - ), - asyncpg.UniqueViolationError: lambda exc: HTTPException( - status_code=409, - detail=str(exc), - ), - asyncpg.NotNullViolationError: lambda exc: HTTPException( - status_code=400, - detail=str(exc), - ), - asyncpg.NoDataFoundError: lambda exc: HTTPException( - status_code=404, - detail="Session not found", - ), - } -) +# @rewrap_exceptions( +# { +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="Entry already exists", +# ), +# asyncpg.NotNullViolationError: partialclass( +# HTTPException, +# status_code=400, +# detail="Entry is required", +# ), +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="Session not found", +# ), +# } +# ) @wrap_in_class(Entry) @increase_counter("list_entries") @pg_query @@ -108,5 +112,8 @@ async def list_entries( [session_id, developer_id], "fetchrow", ), - (query, entry_params), + ( + query, + entry_params, + ), ] From 2ba91ad2eeb66ff039d184dd28324e8f99672bc0 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Wed, 18 Dec 2024 23:19:36 +0000 Subject: [PATCH 24/29] refactor: Lint agents-api (CI) --- agents-api/agents_api/queries/agents/patch_agent.py | 1 - agents-api/agents_api/queries/entries/create_entries.py | 2 +- agents-api/agents_api/queries/entries/delete_entries.py | 2 +- agents-api/agents_api/queries/entries/get_history.py | 2 +- agents-api/agents_api/queries/entries/list_entries.py | 2 +- 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 7fb63feda..2325ab33f 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -17,7 +17,6 @@ wrap_in_class, ) - # Define the raw SQL query agent_query = parse_one(""" UPDATE agents diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 7f6e2d4d7..72de8db90 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -10,7 +10,7 @@ from ...common.utils.datetime import utcnow from ...common.utils.messages import content_to_json from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Query for checking if the session exists session_exists_query = """ diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py index ce1590fd4..4539ae4df 100644 --- a/agents-api/agents_api/queries/entries/delete_entries.py +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -9,7 +9,7 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for deleting entries with a developer check delete_entry_query = parse_one(""" diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index 2c28b4f21..7ad940c0a 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -6,7 +6,7 @@ from sqlglot import parse_one from ...autogen.openapi_model import History -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for getting history with a developer check history_query = parse_one(""" diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 657f5563b..4920e39c1 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -7,7 +7,7 @@ from ...autogen.openapi_model import Entry from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Query for checking if the session exists session_exists_query = """ From 57e453f51260f1458e1b0e2c0c86d8af16f3474a Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Thu, 19 Dec 2024 10:13:50 +0530 Subject: [PATCH 25/29] feat(memory-store,agents-api): Move is_leaf handling to postgres Signed-off-by: Diwank Singh Tomer --- .../agents_api/queries/agents/create_agent.py | 2 -- .../queries/agents/create_or_update_agent.py | 2 -- .../agents_api/queries/agents/delete_agent.py | 2 -- .../agents_api/queries/agents/get_agent.py | 2 -- .../agents_api/queries/agents/list_agents.py | 3 +- .../agents_api/queries/agents/patch_agent.py | 2 -- .../agents_api/queries/agents/update_agent.py | 2 -- .../queries/entries/create_entries.py | 6 +--- .../queries/entries/delete_entries.py | 4 +-- .../agents_api/queries/entries/get_history.py | 4 +-- .../queries/entries/list_entries.py | 3 +- agents-api/tests/test_entry_queries.py | 2 +- agents-api/tests/test_session_queries.py | 1 - .../migrations/000016_entry_relations.up.sql | 34 +++++++++++-------- 14 files changed, 25 insertions(+), 44 deletions(-) diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 2d8df7978..76c96f46b 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -3,7 +3,6 @@ It includes functions to construct and execute SQL queries for inserting new agent records. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype @@ -15,7 +14,6 @@ from ..utils import ( generate_canonical_name, pg_query, - rewrap_exceptions, wrap_in_class, ) diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index e96b30c77..ef3a0abe5 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -3,7 +3,6 @@ It constructs and executes SQL queries to insert a new agent or update an existing agent's details based on agent ID and developer ID. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype @@ -14,7 +13,6 @@ from ..utils import ( generate_canonical_name, pg_query, - rewrap_exceptions, wrap_in_class, ) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 6738374db..3527f3611 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -3,7 +3,6 @@ It constructs and executes SQL queries to remove agent records and associated data. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype @@ -13,7 +12,6 @@ from ...common.utils.datetime import utcnow from ..utils import ( pg_query, - rewrap_exceptions, wrap_in_class, ) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index 916572db1..a731300fa 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -3,7 +3,6 @@ It constructs and executes SQL queries to fetch agent details based on agent ID and developer ID. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype @@ -12,7 +11,6 @@ from ...autogen.openapi_model import Agent from ..utils import ( pg_query, - rewrap_exceptions, wrap_in_class, ) diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index ce12b32b3..87a0c942d 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -3,7 +3,7 @@ It constructs and executes SQL queries to fetch a list of agents based on developer ID with pagination. """ -from typing import Any, Literal, TypeVar +from typing import Any, Literal from uuid import UUID from beartype import beartype @@ -12,7 +12,6 @@ from ...autogen.openapi_model import Agent from ..utils import ( pg_query, - rewrap_exceptions, wrap_in_class, ) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index 2325ab33f..69a5a6ca5 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -3,7 +3,6 @@ It constructs and executes SQL queries to update specific fields of an agent based on agent ID and developer ID. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype @@ -13,7 +12,6 @@ from ...metrics.counters import increase_counter from ..utils import ( pg_query, - rewrap_exceptions, wrap_in_class, ) diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index 79b520cb8..f28e28264 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -3,7 +3,6 @@ It constructs and executes SQL queries to replace an agent's details based on agent ID and developer ID. """ -from typing import Any, TypeVar from uuid import UUID from beartype import beartype @@ -13,7 +12,6 @@ from ...metrics.counters import increase_counter from ..utils import ( pg_query, - rewrap_exceptions, wrap_in_class, ) diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 72de8db90..fb61b7c7e 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -1,16 +1,14 @@ from typing import Literal from uuid import UUID -import asyncpg from beartype import beartype -from fastapi import HTTPException from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation from ...common.utils.datetime import utcnow from ...common.utils.messages import content_to_json from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, wrap_in_class # Query for checking if the session exists session_exists_query = """ @@ -47,7 +45,6 @@ head, relation, tail, - is_leaf ) VALUES ($1, $2, $3, $4, $5) RETURNING *; """ @@ -169,7 +166,6 @@ async def add_entry_relations( item.get("head"), # $2 item.get("relation"), # $3 item.get("tail"), # $4 - item.get("is_leaf", False), # $5 ] ) diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py index 4539ae4df..628ef9011 100644 --- a/agents-api/agents_api/queries/entries/delete_entries.py +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -1,15 +1,13 @@ from typing import Literal from uuid import UUID -import asyncpg from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, wrap_in_class # Define the raw SQL query for deleting entries with a developer check delete_entry_query = parse_one(""" diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index 7ad940c0a..b0b767c08 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -1,12 +1,10 @@ from uuid import UUID -import asyncpg from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import History -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, wrap_in_class # Define the raw SQL query for getting history with a developer check history_query = parse_one(""" diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 4920e39c1..a6c355f53 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -1,13 +1,12 @@ from typing import Literal from uuid import UUID -import asyncpg from beartype import beartype from fastapi import HTTPException from ...autogen.openapi_model import Entry from ...metrics.counters import increase_counter -from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +from ..utils import pg_query, wrap_in_class # Query for checking if the session exists session_exists_query = """ diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 60a387591..f5b9d8d56 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -10,7 +10,7 @@ from agents_api.autogen.openapi_model import CreateEntryRequest from agents_api.clients.pg import create_db_pool from agents_api.queries.entries import create_entries, list_entries -from tests.fixtures import pg_dsn, test_developer, test_session # , test_session +from tests.fixtures import pg_dsn, test_developer # , test_session MODEL = "gpt-4o-mini" diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 8e512379f..4e04468bf 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -29,7 +29,6 @@ from tests.fixtures import ( pg_dsn, test_agent, - test_developer, test_developer_id, test_session, test_user, diff --git a/memory-store/migrations/000016_entry_relations.up.sql b/memory-store/migrations/000016_entry_relations.up.sql index c61c7cd24..bcdb7fb72 100644 --- a/memory-store/migrations/000016_entry_relations.up.sql +++ b/memory-store/migrations/000016_entry_relations.up.sql @@ -31,25 +31,29 @@ CREATE INDEX idx_entry_relations_components ON entry_relations (session_id, head CREATE INDEX idx_entry_relations_leaf ON entry_relations (session_id, relation, is_leaf); -CREATE -OR REPLACE FUNCTION enforce_leaf_nodes () RETURNS TRIGGER AS $$ +CREATE OR REPLACE FUNCTION auto_update_leaf_status() RETURNS TRIGGER AS $$ BEGIN - IF NEW.is_leaf THEN - -- Ensure no other relations point to this leaf node as a head - IF EXISTS ( - SELECT 1 FROM entry_relations - WHERE tail = NEW.head AND session_id = NEW.session_id - ) THEN - RAISE EXCEPTION 'Cannot assign relations to a leaf node.'; - END IF; - END IF; + -- Set is_leaf = false for any existing rows that will now have this new relation as a child + UPDATE entry_relations + SET is_leaf = false + WHERE session_id = NEW.session_id + AND tail = NEW.head; + + -- Set is_leaf for the new row based on whether it has any children + NEW.is_leaf := NOT EXISTS ( + SELECT 1 + FROM entry_relations + WHERE session_id = NEW.session_id + AND head = NEW.tail + ); + RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE TRIGGER trg_enforce_leaf_nodes BEFORE INSERT -OR -UPDATE ON entry_relations FOR EACH ROW -EXECUTE FUNCTION enforce_leaf_nodes (); +CREATE TRIGGER trg_auto_update_leaf_status +BEFORE INSERT OR UPDATE ON entry_relations +FOR EACH ROW +EXECUTE FUNCTION auto_update_leaf_status(); COMMIT; \ No newline at end of file From bbdbb4b369649073fa2334b05e99d34eb44585f4 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Thu, 19 Dec 2024 12:03:30 +0300 Subject: [PATCH 26/29] fix(agents-api): fix sessions and agents queries / tests --- .../queries/entries/create_entries.py | 2 +- .../sessions/create_or_update_session.py | 32 +++++------ .../queries/sessions/create_session.py | 23 +++++--- .../queries/sessions/patch_session.py | 51 +---------------- .../queries/sessions/update_session.py | 56 +++---------------- agents-api/agents_api/queries/utils.py | 17 +++--- agents-api/tests/fixtures.py | 10 ++-- agents-api/tests/test_agent_queries.py | 5 +- agents-api/tests/test_session_queries.py | 49 +++++++--------- 9 files changed, 78 insertions(+), 167 deletions(-) diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index fb61b7c7e..33dcda984 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -173,7 +173,7 @@ async def add_entry_relations( ( session_exists_query, [session_id, developer_id], - "fetch", + "fetchrow", ), ( entry_relation_query, diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index bc54bf31b..26a353e94 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -61,11 +61,7 @@ participant_type, participant_id ) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; +VALUES ($1, $2, $3, $4); """).sql(pretty=True) @@ -83,16 +79,23 @@ ), } ) -@wrap_in_class(ResourceUpdatedResponse, one=True) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "updated_at": d["updated_at"], + }, +) @increase_counter("create_or_update_session") -@pg_query +@pg_query(return_index=0) @beartype async def create_or_update_session( *, developer_id: UUID, session_id: UUID, data: CreateOrUpdateSessionRequest, -) -> list[tuple[str, list]]: +) -> list[tuple[str, list] | tuple[str, list, str]]: """ Constructs SQL queries to create or update a session and its participant lookups. @@ -139,14 +142,11 @@ async def create_or_update_session( ] # Prepare lookup parameters - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] + lookup_params = [] + for participant_type, participant_id in zip(participant_types, participant_ids): + lookup_params.append([developer_id, session_id, participant_type, participant_id]) return [ - (session_query, session_params), - (lookup_query, lookup_params), + (session_query, session_params, "fetch"), + (lookup_query, lookup_params, "fetchmany"), ] diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index baa3f09d1..91badb281 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -1,12 +1,14 @@ from uuid import UUID +from uuid_extensions import uuid7 import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from ...autogen.openapi_model import CreateSessionRequest, Session +from ...autogen.openapi_model import CreateSessionRequest, Session, ResourceCreatedResponse from ...metrics.counters import increase_counter +from ...common.utils.datetime import utcnow from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL queries @@ -63,14 +65,21 @@ ), } ) -@wrap_in_class(Session, transform=lambda d: {**d, "id": d["session_id"]}) +@wrap_in_class( + Session, + one=True, + transform=lambda d: { + **d, + "id": d["session_id"], + }, +) @increase_counter("create_session") -@pg_query +@pg_query(return_index=0) @beartype async def create_session( *, developer_id: UUID, - session_id: UUID, + session_id: UUID | None = None, data: CreateSessionRequest, ) -> list[tuple[str, list] | tuple[str, list, str]]: """ @@ -87,6 +96,7 @@ async def create_session( # Handle participants users = data.users or ([data.user] if data.user else []) agents = data.agents or ([data.agent] if data.agent else []) + session_id = session_id or uuid7() if not agents: raise HTTPException( @@ -123,10 +133,7 @@ async def create_session( for ptype, pid in zip(participant_types, participant_ids): lookup_params.append([developer_id, session_id, ptype, pid]) - print("*" * 100) - print(lookup_params) - print("*" * 100) return [ - (session_query, session_params), + (session_query, session_params, "fetch"), (lookup_query, lookup_params, "fetchmany"), ] diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py index b14b94a8a..60d82468e 100644 --- a/agents-api/agents_api/queries/sessions/patch_session.py +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -31,25 +31,6 @@ SELECT * FROM updated_session; """).sql(pretty=True) -lookup_query = parse_one(""" -WITH deleted_lookups AS ( - DELETE FROM session_lookup - WHERE developer_id = $1 AND session_id = $2 -) -INSERT INTO session_lookup ( - developer_id, - session_id, - participant_type, - participant_id -) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; -""").sql(pretty=True) - - @rewrap_exceptions( { asyncpg.ForeignKeyViolationError: partialclass( @@ -64,7 +45,7 @@ ), } ) -@wrap_in_class(ResourceUpdatedResponse, one=True) +@wrap_in_class(ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]},) @increase_counter("patch_session") @pg_query @beartype @@ -85,22 +66,6 @@ async def patch_session( Returns: list[tuple[str, list]]: List of SQL queries and their parameters """ - # Handle participants - users = data.users or ([data.user] if data.user else []) - agents = data.agents or ([data.agent] if data.agent else []) - - if data.agent and data.agents: - raise HTTPException( - status_code=400, - detail="Only one of 'agent' or 'agents' should be provided", - ) - - # Prepare participant arrays for lookup query if participants are provided - participant_types = [] - participant_ids = [] - if users or agents: - participant_types = ["user"] * len(users) + ["agent"] * len(agents) - participant_ids = [str(u) for u in users] + [str(a) for a in agents] # Extract fields from data, using None for unset fields session_params = [ @@ -116,16 +81,4 @@ async def patch_session( data.recall_options or {}, # $10 ] - queries = [(session_query, session_params)] - - # Only add lookup query if participants are provided - if participant_types: - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] - queries.append((lookup_query, lookup_params)) - - return queries + return [(session_query, session_params)] diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py index 01e21e732..7c58d10e6 100644 --- a/agents-api/agents_api/queries/sessions/update_session.py +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -27,24 +27,6 @@ RETURNING *; """).sql(pretty=True) -lookup_query = parse_one(""" -WITH deleted_lookups AS ( - DELETE FROM session_lookup - WHERE developer_id = $1 AND session_id = $2 -) -INSERT INTO session_lookup ( - developer_id, - session_id, - participant_type, - participant_id -) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; -""").sql(pretty=True) - @rewrap_exceptions( { @@ -60,7 +42,14 @@ ), } ) -@wrap_in_class(ResourceUpdatedResponse, one=True) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "updated_at": d["updated_at"], + }, +) @increase_counter("update_session") @pg_query @beartype @@ -81,26 +70,6 @@ async def update_session( Returns: list[tuple[str, list]]: List of SQL queries and their parameters """ - # Handle participants - users = data.users or ([data.user] if data.user else []) - agents = data.agents or ([data.agent] if data.agent else []) - - if not agents: - raise HTTPException( - status_code=400, - detail="At least one agent must be provided", - ) - - if data.agent and data.agents: - raise HTTPException( - status_code=400, - detail="Only one of 'agent' or 'agents' should be provided", - ) - - # Prepare participant arrays for lookup query - participant_types = ["user"] * len(users) + ["agent"] * len(agents) - participant_ids = [str(u) for u in users] + [str(a) for a in agents] - # Prepare session parameters session_params = [ developer_id, # $1 @@ -115,15 +84,6 @@ async def update_session( data.recall_options or {}, # $10 ] - # Prepare lookup parameters - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] - return [ (session_query, session_params), - (lookup_query, lookup_params), ] diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 73113580d..4126c91dc 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -123,6 +123,7 @@ def pg_query( debug: bool | None = None, only_on_error: bool = False, timeit: bool = False, + return_index: int = -1, ) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]: def pg_query_dec( func: Callable[P, PGQueryArgs | list[PGQueryArgs]], @@ -159,6 +160,8 @@ async def wrapper( async with pool.acquire() as conn: async with conn.transaction(): start = timeit and time.perf_counter() + all_results = [] + for method_name, payload in batch: method = getattr(conn, method_name) @@ -169,11 +172,7 @@ async def wrapper( results: list[Record] = await method( query, *args, timeout=timeout ) - - print("%" * 100) - print(results) - print(*args) - print("%" * 100) + all_results.append(results) if method_name == "fetchrow" and ( len(results) == 0 or results.get("bool") is None @@ -204,9 +203,11 @@ async def wrapper( raise - not only_on_error and debug and pprint(results) - - return results + # Return results from specified index + results_to_return = all_results[return_index] if all_results else [] + not only_on_error and debug and pprint(results_to_return) + + return results_to_return # Set the wrapped function as an attribute of the wrapper, # forwards the __wrapped__ attribute if it exists. diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 9153785a4..49c2e7094 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -96,7 +96,7 @@ def patch_embed_acompletion(): yield embed, acompletion -@fixture(scope="global") +@fixture(scope="test") async def test_agent(dsn=pg_dsn, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -105,18 +105,16 @@ async def test_agent(dsn=pg_dsn, developer=test_developer): data=CreateAgentRequest( model="gpt-4o-mini", name="test agent", - canonical_name=f"test_agent_{str(int(time.time()))}", about="test agent about", metadata={"test": "test"}, ), connection_pool=pool, ) - yield agent - await pool.close() + return agent -@fixture(scope="global") +@fixture(scope="test") async def test_user(dsn=pg_dsn, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -153,7 +151,7 @@ async def test_new_developer(dsn=pg_dsn, email=random_email): return developer -@fixture(scope="global") +@fixture(scope="test") async def test_session( dsn=pg_dsn, developer_id=test_developer_id, diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index b6cb7aedc..594047a82 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -41,7 +41,8 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) -@test("query: create agent with instructions sql") + +@test("query: create or update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully created or updated.""" @@ -60,6 +61,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) + @test("query: update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): """Test that an existing agent's information can be successfully updated.""" @@ -81,7 +83,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): assert result is not None assert isinstance(result, ResourceUpdatedResponse) - @test("query: get agent not exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that retrieving a non-existent agent raises an exception.""" diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 4e04468bf..ec2e511d4 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -7,13 +7,15 @@ from ward import raises, test from agents_api.autogen.openapi_model import ( + Session, CreateOrUpdateSessionRequest, CreateSessionRequest, + UpdateSessionRequest, PatchSessionRequest, ResourceDeletedResponse, ResourceUpdatedResponse, - Session, - UpdateSessionRequest, + ResourceDeletedResponse, + ResourceCreatedResponse, ) from agents_api.clients.pg import create_db_pool from agents_api.queries.sessions import ( @@ -46,7 +48,6 @@ async def _( data = CreateSessionRequest( users=[user.id], agents=[agent.id], - situation="test session", system_template="test system template", ) result = await create_session( @@ -59,10 +60,6 @@ async def _( assert result is not None assert isinstance(result, Session), f"Result is not a Session, {result}" assert result.id == session_id - assert result.developer_id == developer_id - assert result.situation == "test session" - assert set(result.users) == {user.id} - assert set(result.agents) == {agent.id} @test("query: create or update session sql") @@ -76,7 +73,7 @@ async def _( data = CreateOrUpdateSessionRequest( users=[user.id], agents=[agent.id], - situation="test session", + system_template="test system template", ) result = await create_or_update_session( developer_id=developer_id, @@ -86,12 +83,9 @@ async def _( ) assert result is not None - assert isinstance(result, Session) + assert isinstance(result, ResourceUpdatedResponse) assert result.id == session_id - assert result.developer_id == developer_id - assert result.situation == "test session" - assert set(result.users) == {user.id} - assert set(result.agents) == {agent.id} + assert result.updated_at is not None @test("query: get session exists") @@ -108,7 +102,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): assert result is not None assert isinstance(result, Session) assert result.id == session.id - assert result.developer_id == developer_id @test("query: get session does not exist") @@ -130,7 +123,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): """Test listing sessions with default pagination.""" pool = await create_db_pool(dsn=dsn) - result, _ = await list_sessions( + result = await list_sessions( developer_id=developer_id, limit=10, offset=0, @@ -147,17 +140,18 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): """Test listing sessions with specific filters.""" pool = await create_db_pool(dsn=dsn) - result, _ = await list_sessions( + result = await list_sessions( developer_id=developer_id, limit=10, offset=0, - filters={"situation": "test session"}, connection_pool=pool, ) assert isinstance(result, list) assert len(result) >= 1 - assert all(s.situation == "test session" for s in result) + assert all( + s.situation == session.situation for s in result + ), f"Result is not a list of sessions, {result}, {session.situation}" @test("query: count sessions") @@ -170,20 +164,21 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): connection_pool=pool, ) - assert isinstance(count, int) - assert count >= 1 + assert isinstance(count, dict) + assert count["count"] >= 1 @test("query: update session sql") async def _( - dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent, user=test_user ): """Test that an existing session's information can be successfully updated.""" pool = await create_db_pool(dsn=dsn) data = UpdateSessionRequest( - agents=[agent.id], - situation="updated session", + token_budget=1000, + forward_tool_calls=True, + system_template="updated system template", ) result = await update_session( session_id=session.id, @@ -201,8 +196,7 @@ async def _( session_id=session.id, connection_pool=pool, ) - assert updated_session.situation == "updated session" - assert set(updated_session.agents) == {agent.id} + assert updated_session.forward_tool_calls is True @test("query: patch session sql") @@ -213,8 +207,6 @@ async def _( pool = await create_db_pool(dsn=dsn) data = PatchSessionRequest( - agents=[agent.id], - situation="patched session", metadata={"test": "metadata"}, ) result = await patch_session( @@ -233,8 +225,7 @@ async def _( session_id=session.id, connection_pool=pool, ) - assert patched_session.situation == "patched session" - assert set(patched_session.agents) == {agent.id} + assert patched_session.situation == session.situation assert patched_session.metadata == {"test": "metadata"} From 8361e7d33e272d193bcd83f15248741751dfde85 Mon Sep 17 00:00:00 2001 From: Ahmad-mtos Date: Thu, 19 Dec 2024 09:06:04 +0000 Subject: [PATCH 27/29] refactor: Lint agents-api (CI) --- .../queries/sessions/create_or_update_session.py | 4 +++- .../agents_api/queries/sessions/create_session.py | 10 +++++++--- .../agents_api/queries/sessions/patch_session.py | 7 ++++++- agents-api/agents_api/queries/utils.py | 4 ++-- agents-api/tests/test_agent_queries.py | 3 +-- agents-api/tests/test_session_queries.py | 13 ++++++++----- 6 files changed, 27 insertions(+), 14 deletions(-) diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index 26a353e94..3c4dbf66e 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -144,7 +144,9 @@ async def create_or_update_session( # Prepare lookup parameters lookup_params = [] for participant_type, participant_id in zip(participant_types, participant_ids): - lookup_params.append([developer_id, session_id, participant_type, participant_id]) + lookup_params.append( + [developer_id, session_id, participant_type, participant_id] + ) return [ (session_query, session_params, "fetch"), diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 91badb281..63fbdc940 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -1,14 +1,18 @@ from uuid import UUID -from uuid_extensions import uuid7 import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one +from uuid_extensions import uuid7 -from ...autogen.openapi_model import CreateSessionRequest, Session, ResourceCreatedResponse -from ...metrics.counters import increase_counter +from ...autogen.openapi_model import ( + CreateSessionRequest, + ResourceCreatedResponse, + Session, +) from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL queries diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py index 60d82468e..7d526ae1a 100644 --- a/agents-api/agents_api/queries/sessions/patch_session.py +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -31,6 +31,7 @@ SELECT * FROM updated_session; """).sql(pretty=True) + @rewrap_exceptions( { asyncpg.ForeignKeyViolationError: partialclass( @@ -45,7 +46,11 @@ ), } ) -@wrap_in_class(ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]},) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]}, +) @increase_counter("patch_session") @pg_query @beartype diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 4126c91dc..0c20ca59e 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -161,7 +161,7 @@ async def wrapper( async with conn.transaction(): start = timeit and time.perf_counter() all_results = [] - + for method_name, payload in batch: method = getattr(conn, method_name) @@ -206,7 +206,7 @@ async def wrapper( # Return results from specified index results_to_return = all_results[return_index] if all_results else [] not only_on_error and debug and pprint(results_to_return) - + return results_to_return # Set the wrapped function as an attribute of the wrapper, diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 594047a82..85d10f6ea 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -41,7 +41,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) - @test("query: create or update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully created or updated.""" @@ -61,7 +60,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) - @test("query: update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): """Test that an existing agent's information can be successfully updated.""" @@ -83,6 +81,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): assert result is not None assert isinstance(result, ResourceUpdatedResponse) + @test("query: get agent not exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that retrieving a non-existent agent raises an exception.""" diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index ec2e511d4..5f2190e2b 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -7,15 +7,14 @@ from ward import raises, test from agents_api.autogen.openapi_model import ( - Session, CreateOrUpdateSessionRequest, CreateSessionRequest, - UpdateSessionRequest, PatchSessionRequest, + ResourceCreatedResponse, ResourceDeletedResponse, ResourceUpdatedResponse, - ResourceDeletedResponse, - ResourceCreatedResponse, + Session, + UpdateSessionRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.queries.sessions import ( @@ -170,7 +169,11 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): @test("query: update session sql") async def _( - dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent, user=test_user + dsn=pg_dsn, + developer_id=test_developer_id, + session=test_session, + agent=test_agent, + user=test_user, ): """Test that an existing session's information can be successfully updated.""" From e158f3adbd41aaeb996cd3a62c0401ca1aa21eaa Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Thu, 19 Dec 2024 19:45:43 +0530 Subject: [PATCH 28/29] feat(agents-api): Remove auto_blob_store in favor of interceptor based system Signed-off-by: Diwank Singh Tomer --- .../agents_api/activities/embed_docs.py | 2 - .../activities/excecute_api_call.py | 2 - .../activities/execute_integration.py | 2 - .../agents_api/activities/execute_system.py | 7 +- .../activities/sync_items_remote.py | 12 +- .../activities/task_steps/base_evaluate.py | 2 - .../activities/task_steps/cozo_query_step.py | 2 - .../activities/task_steps/evaluate_step.py | 2 - .../activities/task_steps/for_each_step.py | 2 - .../activities/task_steps/get_value_step.py | 5 +- .../activities/task_steps/if_else_step.py | 2 - .../activities/task_steps/log_step.py | 2 - .../activities/task_steps/map_reduce_step.py | 2 - .../activities/task_steps/prompt_step.py | 2 - .../task_steps/raise_complete_async.py | 2 - .../activities/task_steps/return_step.py | 2 - .../activities/task_steps/set_value_step.py | 5 +- .../activities/task_steps/switch_step.py | 2 - .../activities/task_steps/tool_call_step.py | 2 - .../activities/task_steps/transition_step.py | 6 - .../task_steps/wait_for_input_step.py | 2 - .../activities/task_steps/yield_step.py | 2 - agents-api/agents_api/activities/utils.py | 1 - .../agents_api/autogen/openapi_model.py | 3 +- agents-api/agents_api/clients/async_s3.py | 1 + agents-api/agents_api/clients/temporal.py | 9 +- agents-api/agents_api/common/interceptors.py | 189 +++++++++------ .../agents_api/common/protocol/remote.py | 97 ++------ .../agents_api/common/protocol/sessions.py | 2 +- .../agents_api/common/protocol/tasks.py | 23 +- .../agents_api/common/storage_handler.py | 226 ------------------ agents-api/agents_api/env.py | 4 +- .../routers/healthz/check_health.py | 19 ++ .../workflows/task_execution/__init__.py | 12 +- .../workflows/task_execution/helpers.py | 7 - 35 files changed, 181 insertions(+), 481 deletions(-) delete mode 100644 agents-api/agents_api/common/storage_handler.py create mode 100644 agents-api/agents_api/routers/healthz/check_health.py diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py index c6c7663c3..a9a7cae44 100644 --- a/agents-api/agents_api/activities/embed_docs.py +++ b/agents-api/agents_api/activities/embed_docs.py @@ -7,13 +7,11 @@ from temporalio import activity from ..clients import cozo, litellm -from ..common.storage_handler import auto_blob_store from ..env import testing from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query from .types import EmbedDocsPayload -@auto_blob_store(deep=True) @beartype async def embed_docs( payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100 diff --git a/agents-api/agents_api/activities/excecute_api_call.py b/agents-api/agents_api/activities/excecute_api_call.py index 09a33aaa8..2167aaead 100644 --- a/agents-api/agents_api/activities/excecute_api_call.py +++ b/agents-api/agents_api/activities/excecute_api_call.py @@ -6,7 +6,6 @@ from temporalio import activity from ..autogen.openapi_model import ApiCallDef -from ..common.storage_handler import auto_blob_store from ..env import testing @@ -20,7 +19,6 @@ class RequestArgs(TypedDict): headers: Optional[dict[str, str]] -@auto_blob_store(deep=True) @beartype async def execute_api_call( api_call: ApiCallDef, diff --git a/agents-api/agents_api/activities/execute_integration.py b/agents-api/agents_api/activities/execute_integration.py index 3316ad6f5..d058553c4 100644 --- a/agents-api/agents_api/activities/execute_integration.py +++ b/agents-api/agents_api/activities/execute_integration.py @@ -7,12 +7,10 @@ from ..clients import integrations from ..common.exceptions.tools import IntegrationExecutionException from ..common.protocol.tasks import ExecutionInput, StepContext -from ..common.storage_handler import auto_blob_store from ..env import testing from ..models.tools import get_tool_args_from_metadata -@auto_blob_store(deep=True) @beartype async def execute_integration( context: StepContext, diff --git a/agents-api/agents_api/activities/execute_system.py b/agents-api/agents_api/activities/execute_system.py index 590849080..647327a8a 100644 --- a/agents-api/agents_api/activities/execute_system.py +++ b/agents-api/agents_api/activities/execute_system.py @@ -19,16 +19,14 @@ VectorDocSearchRequest, ) from ..common.protocol.tasks import ExecutionInput, StepContext -from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote from ..env import testing -from ..queries.developer import get_developer +from ..queries.developers import get_developer from .utils import get_handler # For running synchronous code in the background process_pool_executor = ProcessPoolExecutor() -@auto_blob_store(deep=True) @beartype async def execute_system( context: StepContext, @@ -37,9 +35,6 @@ async def execute_system( """Execute a system call with the appropriate handler and transformed arguments.""" arguments: dict[str, Any] = system.arguments or {} - if set(arguments.keys()) == {"bucket", "key"}: - arguments = await load_from_blob_store_if_remote(arguments) - if not isinstance(context.execution_input, ExecutionInput): raise TypeError("Expected ExecutionInput type for context.execution_input") diff --git a/agents-api/agents_api/activities/sync_items_remote.py b/agents-api/agents_api/activities/sync_items_remote.py index d71a5c566..14751c2b6 100644 --- a/agents-api/agents_api/activities/sync_items_remote.py +++ b/agents-api/agents_api/activities/sync_items_remote.py @@ -9,20 +9,16 @@ @beartype async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]: - from ..common.storage_handler import store_in_blob_store_if_large + from ..common.interceptors import offload_if_large - return await asyncio.gather( - *[store_in_blob_store_if_large(input) for input in inputs] - ) + return await asyncio.gather(*[offload_if_large(input) for input in inputs]) @beartype async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]: - from ..common.storage_handler import load_from_blob_store_if_remote + from ..common.interceptors import load_if_remote - return await asyncio.gather( - *[load_from_blob_store_if_remote(input) for input in inputs] - ) + return await asyncio.gather(*[load_if_remote(input) for input in inputs]) save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn) diff --git a/agents-api/agents_api/activities/task_steps/base_evaluate.py b/agents-api/agents_api/activities/task_steps/base_evaluate.py index d87b961d3..3bb04e390 100644 --- a/agents-api/agents_api/activities/task_steps/base_evaluate.py +++ b/agents-api/agents_api/activities/task_steps/base_evaluate.py @@ -13,7 +13,6 @@ from temporalio import activity # noqa: E402 from thefuzz import fuzz # noqa: E402 -from ...common.storage_handler import auto_blob_store # noqa: E402 from ...env import testing # noqa: E402 from ..utils import get_evaluator # noqa: E402 @@ -63,7 +62,6 @@ def _recursive_evaluate(expr, evaluator: SimpleEval): raise ValueError(f"Invalid expression: {expr}") -@auto_blob_store(deep=True) @beartype async def base_evaluate( exprs: Any, diff --git a/agents-api/agents_api/activities/task_steps/cozo_query_step.py b/agents-api/agents_api/activities/task_steps/cozo_query_step.py index 16e9a53d8..8d28d83c9 100644 --- a/agents-api/agents_api/activities/task_steps/cozo_query_step.py +++ b/agents-api/agents_api/activities/task_steps/cozo_query_step.py @@ -4,11 +4,9 @@ from temporalio import activity from ... import models -from ...common.storage_handler import auto_blob_store from ...env import testing -@auto_blob_store(deep=True) @beartype async def cozo_query_step( query_name: str, diff --git a/agents-api/agents_api/activities/task_steps/evaluate_step.py b/agents-api/agents_api/activities/task_steps/evaluate_step.py index 904ec3b9d..08fa6cd55 100644 --- a/agents-api/agents_api/activities/task_steps/evaluate_step.py +++ b/agents-api/agents_api/activities/task_steps/evaluate_step.py @@ -5,11 +5,9 @@ from ...activities.utils import simple_eval_dict from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing -@auto_blob_store(deep=True) @beartype async def evaluate_step( context: StepContext, diff --git a/agents-api/agents_api/activities/task_steps/for_each_step.py b/agents-api/agents_api/activities/task_steps/for_each_step.py index f51c1ef76..ca84eb75d 100644 --- a/agents-api/agents_api/activities/task_steps/for_each_step.py +++ b/agents-api/agents_api/activities/task_steps/for_each_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def for_each_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/get_value_step.py b/agents-api/agents_api/activities/task_steps/get_value_step.py index ca38bc4fe..feeb71bbf 100644 --- a/agents-api/agents_api/activities/task_steps/get_value_step.py +++ b/agents-api/agents_api/activities/task_steps/get_value_step.py @@ -2,13 +2,12 @@ from temporalio import activity from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing - # TODO: We should use this step to query the parent workflow and get the value from the workflow context # SCRUM-1 -@auto_blob_store(deep=True) + + @beartype async def get_value_step( context: StepContext, diff --git a/agents-api/agents_api/activities/task_steps/if_else_step.py b/agents-api/agents_api/activities/task_steps/if_else_step.py index cf3764199..ec4368640 100644 --- a/agents-api/agents_api/activities/task_steps/if_else_step.py +++ b/agents-api/agents_api/activities/task_steps/if_else_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def if_else_step(context: StepContext) -> StepOutcome: # NOTE: This activity is only for logging, so we just evaluate the expression diff --git a/agents-api/agents_api/activities/task_steps/log_step.py b/agents-api/agents_api/activities/task_steps/log_step.py index 28fea2dae..f54018683 100644 --- a/agents-api/agents_api/activities/task_steps/log_step.py +++ b/agents-api/agents_api/activities/task_steps/log_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template from ...env import testing -@auto_blob_store(deep=True) @beartype async def log_step(context: StepContext) -> StepOutcome: # NOTE: This activity is only for logging, so we just evaluate the expression diff --git a/agents-api/agents_api/activities/task_steps/map_reduce_step.py b/agents-api/agents_api/activities/task_steps/map_reduce_step.py index 872988bb4..c39bace20 100644 --- a/agents-api/agents_api/activities/task_steps/map_reduce_step.py +++ b/agents-api/agents_api/activities/task_steps/map_reduce_step.py @@ -8,12 +8,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def map_reduce_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index cf8b169d5..47560cadd 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -8,7 +8,6 @@ litellm, # We dont directly import `acompletion` so we can mock it ) from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...common.utils.template import render_template from ...env import debug from .base_evaluate import base_evaluate @@ -62,7 +61,6 @@ def format_tool(tool: Tool) -> dict: @activity.defn -@auto_blob_store(deep=True) @beartype async def prompt_step(context: StepContext) -> StepOutcome: # Get context data diff --git a/agents-api/agents_api/activities/task_steps/raise_complete_async.py b/agents-api/agents_api/activities/task_steps/raise_complete_async.py index 640d6ae4e..bbf27c500 100644 --- a/agents-api/agents_api/activities/task_steps/raise_complete_async.py +++ b/agents-api/agents_api/activities/task_steps/raise_complete_async.py @@ -6,12 +6,10 @@ from ...autogen.openapi_model import CreateTransitionRequest from ...common.protocol.tasks import StepContext -from ...common.storage_handler import auto_blob_store from .transition_step import original_transition_step @activity.defn -@auto_blob_store(deep=True) @beartype async def raise_complete_async(context: StepContext, output: Any) -> None: activity_info = activity.info() diff --git a/agents-api/agents_api/activities/task_steps/return_step.py b/agents-api/agents_api/activities/task_steps/return_step.py index 08ac20de4..f15354536 100644 --- a/agents-api/agents_api/activities/task_steps/return_step.py +++ b/agents-api/agents_api/activities/task_steps/return_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def return_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/set_value_step.py b/agents-api/agents_api/activities/task_steps/set_value_step.py index 1c97b6551..96db5d0d1 100644 --- a/agents-api/agents_api/activities/task_steps/set_value_step.py +++ b/agents-api/agents_api/activities/task_steps/set_value_step.py @@ -5,13 +5,12 @@ from ...activities.utils import simple_eval_dict from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing - # TODO: We should use this step to signal to the parent workflow and set the value on the workflow context # SCRUM-2 -@auto_blob_store(deep=True) + + @beartype async def set_value_step( context: StepContext, diff --git a/agents-api/agents_api/activities/task_steps/switch_step.py b/agents-api/agents_api/activities/task_steps/switch_step.py index 6a95e98d2..100d8020a 100644 --- a/agents-api/agents_api/activities/task_steps/switch_step.py +++ b/agents-api/agents_api/activities/task_steps/switch_step.py @@ -6,12 +6,10 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store from ...env import testing from ..utils import get_evaluator -@auto_blob_store(deep=True) @beartype async def switch_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py index 5725a75d1..a2d7fd7c2 100644 --- a/agents-api/agents_api/activities/task_steps/tool_call_step.py +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -11,7 +11,6 @@ StepContext, StepOutcome, ) -from ...common.storage_handler import auto_blob_store # FIXME: This shouldn't be here. @@ -47,7 +46,6 @@ def construct_tool_call( @activity.defn -@auto_blob_store(deep=True) @beartype async def tool_call_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, ToolCallStep) diff --git a/agents-api/agents_api/activities/task_steps/transition_step.py b/agents-api/agents_api/activities/task_steps/transition_step.py index 44046a5e7..11c7befb5 100644 --- a/agents-api/agents_api/activities/task_steps/transition_step.py +++ b/agents-api/agents_api/activities/task_steps/transition_step.py @@ -8,7 +8,6 @@ from ...autogen.openapi_model import CreateTransitionRequest, Transition from ...clients.temporal import get_workflow_handle from ...common.protocol.tasks import ExecutionInput, StepContext -from ...common.storage_handler import load_from_blob_store_if_remote from ...env import ( temporal_activity_after_retry_timeout, testing, @@ -48,11 +47,6 @@ async def transition_step( TaskExecutionWorkflow.set_last_error, LastErrorInput(last_error=None) ) - # Load output from blob store if it is a remote object - transition_info.output = await load_from_blob_store_if_remote( - transition_info.output - ) - if not isinstance(context.execution_input, ExecutionInput): raise TypeError("Expected ExecutionInput type for context.execution_input") diff --git a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py index ad6eeb63e..a3cb00f67 100644 --- a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py +++ b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py @@ -3,12 +3,10 @@ from ...autogen.openapi_model import WaitForInputStep from ...common.protocol.tasks import StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def wait_for_input_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py index 199008703..18e5383cc 100644 --- a/agents-api/agents_api/activities/task_steps/yield_step.py +++ b/agents-api/agents_api/activities/task_steps/yield_step.py @@ -5,12 +5,10 @@ from ...autogen.openapi_model import TransitionTarget, YieldStep from ...common.protocol.tasks import ExecutionInput, StepContext, StepOutcome -from ...common.storage_handler import auto_blob_store from ...env import testing from .base_evaluate import base_evaluate -@auto_blob_store(deep=True) @beartype async def yield_step(context: StepContext) -> StepOutcome: try: diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index d9ad1840c..cedc01695 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -304,7 +304,6 @@ def get_handler(system: SystemDef) -> Callable: from ..models.docs.delete_doc import delete_doc as delete_doc_query from ..models.docs.list_docs import list_docs as list_docs_query from ..models.session.create_session import create_session as create_session_query - from ..models.session.delete_session import delete_session as delete_session_query from ..models.session.get_session import get_session as get_session_query from ..models.session.list_sessions import list_sessions as list_sessions_query from ..models.session.update_session import update_session as update_session_query diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index af73e8015..d809e0a35 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -14,7 +14,6 @@ model_validator, ) -from ..common.storage_handler import RemoteObject from ..common.utils.datetime import utcnow from .Agents import * from .Chat import * @@ -358,7 +357,7 @@ def validate_subworkflows(self): class SystemDef(SystemDef): - arguments: dict[str, Any] | None | RemoteObject = None + arguments: dict[str, Any] | None = None class CreateTransitionRequest(Transition): diff --git a/agents-api/agents_api/clients/async_s3.py b/agents-api/agents_api/clients/async_s3.py index 0cd5235ee..b6ba76d8b 100644 --- a/agents-api/agents_api/clients/async_s3.py +++ b/agents-api/agents_api/clients/async_s3.py @@ -16,6 +16,7 @@ ) +@alru_cache(maxsize=1024) async def list_buckets() -> list[str]: session = get_session() diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index da2d7f6fa..cd2178d95 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -1,3 +1,4 @@ +import asyncio from datetime import timedelta from uuid import UUID @@ -12,9 +13,9 @@ from temporalio.runtime import PrometheusConfig, Runtime, TelemetryConfig from ..autogen.openapi_model import TransitionTarget +from ..common.interceptors import offload_if_large from ..common.protocol.tasks import ExecutionInput from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..common.storage_handler import store_in_blob_store_if_large from ..env import ( temporal_client_cert, temporal_metrics_bind_host, @@ -96,8 +97,10 @@ async def run_task_execution_workflow( client = client or (await get_client()) execution_id = execution_input.execution.id execution_id_key = SearchAttributeKey.for_keyword("CustomStringField") - execution_input.arguments = await store_in_blob_store_if_large( - execution_input.arguments + + old_args = execution_input.arguments + execution_input.arguments = await asyncio.gather( + *[offload_if_large(arg) for arg in old_args] ) return await client.start_workflow( diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index 40600a818..bfd64c374 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -4,8 +4,12 @@ certain types of errors that are known to be non-retryable. """ -from typing import Optional, Type +import asyncio +import sys +from functools import wraps +from typing import Any, Awaitable, Callable, Optional, Sequence, Type +from temporalio import workflow from temporalio.activity import _CompleteAsyncError as CompleteAsyncError from temporalio.exceptions import ApplicationError, FailureError, TemporalError from temporalio.service import RPCError @@ -23,7 +27,97 @@ ReadOnlyContextError, ) -from .exceptions.tasks import is_retryable_error +with workflow.unsafe.imports_passed_through(): + from ..env import blob_store_cutoff_kb, use_blob_store_for_temporal + from .exceptions.tasks import is_retryable_error + from .protocol.remote import RemoteObject + +# Common exceptions that should be re-raised without modification +PASSTHROUGH_EXCEPTIONS = ( + ContinueAsNewError, + ReadOnlyContextError, + NondeterminismError, + RPCError, + CompleteAsyncError, + TemporalError, + FailureError, + ApplicationError, +) + + +def is_too_large(result: Any) -> bool: + return sys.getsizeof(result) > blob_store_cutoff_kb * 1024 + + +async def load_if_remote[T](arg: T | RemoteObject[T]) -> T: + if use_blob_store_for_temporal and isinstance(arg, RemoteObject): + return await arg.load() + + return arg + + +async def offload_if_large[T](result: T) -> T: + if use_blob_store_for_temporal and is_too_large(result): + return await RemoteObject.from_value(result) + + return result + + +def offload_to_blob_store[S, T]( + func: Callable[[S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T]], +) -> Callable[ + [S, ExecuteActivityInput | ExecuteWorkflowInput], Awaitable[T | RemoteObject[T]] +]: + @wraps(func) + async def wrapper( + self, + input: ExecuteActivityInput | ExecuteWorkflowInput, + ) -> T | RemoteObject[T]: + # Load all remote arguments from the blob store + args: Sequence[Any] = input.args + + if use_blob_store_for_temporal: + input.args = await asyncio.gather(*[load_if_remote(arg) for arg in args]) + + # Execute the function + result = await func(self, input) + + # Save the result to the blob store if necessary + return await offload_if_large(result) + + return wrapper + + +async def handle_execution_with_errors[I, T]( + execution_fn: Callable[[I], Awaitable[T]], + input: I, +) -> T: + """ + Common error handling logic for both activities and workflows. + + Args: + execution_fn: Async function to execute with error handling + input: Input to the execution function + + Returns: + The result of the execution function + + Raises: + ApplicationError: For non-retryable errors + Any other exception: For retryable errors + """ + try: + return await execution_fn(input) + except PASSTHROUGH_EXCEPTIONS: + raise + except BaseException as e: + if not is_retryable_error(e): + raise ApplicationError( + str(e), + type=type(e).__name__, + non_retryable=True, + ) + raise class CustomActivityInterceptor(ActivityInboundInterceptor): @@ -35,95 +129,45 @@ class CustomActivityInterceptor(ActivityInboundInterceptor): as non-retryable errors. """ - async def execute_activity(self, input: ExecuteActivityInput): + @offload_to_blob_store + async def execute_activity(self, input: ExecuteActivityInput) -> Any: """ - 🎭 The Activity Whisperer: Handles activity execution with style and grace - - This is like a safety net for your activities - catching errors and deciding - their fate with the wisdom of a fortune cookie. + Handles activity execution by intercepting errors and determining their retry behavior. """ - try: - return await super().execute_activity(input) - except ( - ContinueAsNewError, # When you need a fresh start - ReadOnlyContextError, # When someone tries to write in a museum - NondeterminismError, # When chaos theory kicks in - RPCError, # When computers can't talk to each other - CompleteAsyncError, # When async goes wrong - TemporalError, # When time itself rebels - FailureError, # When failure is not an option, but happens anyway - ApplicationError, # When the app says "nope" - ): - raise - except BaseException as e: - if not is_retryable_error(e): - # If it's not retryable, we wrap it in a nice bow (ApplicationError) - # and mark it as non-retryable to prevent further attempts - raise ApplicationError( - str(e), - type=type(e).__name__, - non_retryable=True, - ) - # For retryable errors, we'll let Temporal retry with backoff - # Default retry policy ensures at least 2 retries - raise + return await handle_execution_with_errors( + super().execute_activity, + input, + ) class CustomWorkflowInterceptor(WorkflowInboundInterceptor): """ - 🎪 The Workflow Circus Ringmaster + Custom interceptor for Temporal workflows. - This interceptor is like a circus ringmaster - keeping all the workflow acts - running smoothly and catching any lions (errors) that escape their cages. + Handles workflow execution errors and determines their retry behavior. """ - async def execute_workflow(self, input: ExecuteWorkflowInput): + @offload_to_blob_store + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: """ - 🎪 The Main Event: Workflow Execution Extravaganza! - - Watch as we gracefully handle errors like a trapeze artist catching their partner! + Executes workflows and handles error cases appropriately. """ - try: - return await super().execute_workflow(input) - except ( - ContinueAsNewError, # The show must go on! - ReadOnlyContextError, # No touching, please! - NondeterminismError, # When butterflies cause hurricanes - RPCError, # Lost in translation - CompleteAsyncError, # Async said "bye" too soon - TemporalError, # Time is relative, errors are absolute - FailureError, # Task failed successfully - ApplicationError, # App.exe has stopped working - ): - raise - except BaseException as e: - if not is_retryable_error(e): - # Pack the error in a nice box with a "do not retry" sticker - raise ApplicationError( - str(e), - type=type(e).__name__, - non_retryable=True, - ) - # Let it retry - everyone deserves a second (or third) chance! - raise + return await handle_execution_with_errors( + super().execute_workflow, + input, + ) class CustomInterceptor(Interceptor): """ - 🎭 The Grand Interceptor: Master of Ceremonies - - This is like the backstage manager of a theater - making sure both the - activity actors and workflow directors have their interceptor costumes on. + Main interceptor class that provides both activity and workflow interceptors. """ def intercept_activity( self, next: ActivityInboundInterceptor ) -> ActivityInboundInterceptor: """ - 🎬 Activity Interceptor Factory: Where the magic begins! - - Creating custom activity interceptors faster than a caffeinated barista - makes espresso shots. + Creates and returns a custom activity interceptor. """ return CustomActivityInterceptor(super().intercept_activity(next)) @@ -131,9 +175,6 @@ def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput ) -> Optional[Type[WorkflowInboundInterceptor]]: """ - 🎪 Workflow Interceptor Class Selector - - Like a matchmaker for workflows and their interceptors - a match made in - exception handling heaven! + Returns the custom workflow interceptor class. """ return CustomWorkflowInterceptor diff --git a/agents-api/agents_api/common/protocol/remote.py b/agents-api/agents_api/common/protocol/remote.py index ce2a2a63a..86add1949 100644 --- a/agents-api/agents_api/common/protocol/remote.py +++ b/agents-api/agents_api/common/protocol/remote.py @@ -1,91 +1,34 @@ from dataclasses import dataclass -from typing import Any +from typing import Generic, Self, Type, TypeVar, cast -from temporalio import activity, workflow +from temporalio import workflow with workflow.unsafe.imports_passed_through(): - from pydantic import BaseModel - + from ...clients import async_s3 from ...env import blob_store_bucket + from ...worker.codec import deserialize, serialize -@dataclass -class RemoteObject: - key: str - bucket: str = blob_store_bucket - - -class BaseRemoteModel(BaseModel): - _remote_cache: dict[str, Any] - - class Config: - arbitrary_types_allowed = True - - def __init__(self, **data: Any): - super().__init__(**data) - self._remote_cache = {} - - async def load_item(self, item: Any | RemoteObject) -> Any: - if not activity.in_activity(): - return item - - from ..storage_handler import load_from_blob_store_if_remote - - return await load_from_blob_store_if_remote(item) +T = TypeVar("T") - async def save_item(self, item: Any) -> Any: - if not activity.in_activity(): - return item - from ..storage_handler import store_in_blob_store_if_large - - return await store_in_blob_store_if_large(item) - - async def get_attribute(self, name: str) -> Any: - if name.startswith("_"): - return super().__getattribute__(name) - - try: - value = super().__getattribute__(name) - except AttributeError: - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) - - if isinstance(value, RemoteObject): - cache = super().__getattribute__("_remote_cache") - if name in cache: - return cache[name] - - loaded_data = await self.load_item(value) - cache[name] = loaded_data - return loaded_data - - return value - - async def set_attribute(self, name: str, value: Any) -> None: - if name.startswith("_"): - super().__setattr__(name, value) - return +@dataclass +class RemoteObject(Generic[T]): + _type: Type[T] + key: str + bucket: str - stored_value = await self.save_item(value) - super().__setattr__(name, stored_value) + @classmethod + async def from_value(cls, x: T) -> Self: + await async_s3.setup() - if isinstance(stored_value, RemoteObject): - cache = self.__dict__.get("_remote_cache", {}) - cache.pop(name, None) + serialized = serialize(x) - async def load_all(self) -> None: - for name in self.model_fields_set: - await self.get_attribute(name) + key = await async_s3.add_object_with_hash(serialized) + return RemoteObject[T](key=key, bucket=blob_store_bucket, _type=type(x)) - async def unload_attribute(self, name: str) -> None: - if name in self._remote_cache: - data = self._remote_cache.pop(name) - remote_obj = await self.save_item(data) - super().__setattr__(name, remote_obj) + async def load(self) -> T: + await async_s3.setup() - async def unload_all(self) -> "BaseRemoteModel": - for name in list(self._remote_cache.keys()): - await self.unload_attribute(name) - return self + fetched = await async_s3.get_object(self.key) + return cast(self._type, deserialize(fetched)) diff --git a/agents-api/agents_api/common/protocol/sessions.py b/agents-api/agents_api/common/protocol/sessions.py index 121afe702..3b04178e1 100644 --- a/agents-api/agents_api/common/protocol/sessions.py +++ b/agents-api/agents_api/common/protocol/sessions.py @@ -103,7 +103,7 @@ def get_active_tools(self) -> list[Tool]: return active_toolset.tools - def get_chat_environment(self) -> dict[str, dict | list[dict]]: + def get_chat_environment(self) -> dict[str, dict | list[dict] | None]: """ Get the chat environment from the session data. """ diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 430a62f36..f3bb81d07 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -1,9 +1,8 @@ -import asyncio from typing import Annotated, Any, Literal from uuid import UUID from beartype import beartype -from temporalio import activity, workflow +from temporalio import workflow from temporalio.exceptions import ApplicationError with workflow.unsafe.imports_passed_through(): @@ -33,8 +32,6 @@ Workflow, WorkflowStep, ) - from ...common.storage_handler import load_from_blob_store_if_remote - from .remote import BaseRemoteModel, RemoteObject # TODO: Maybe we should use a library for this @@ -146,16 +143,16 @@ class ExecutionInput(BaseModel): task: TaskSpecDef agent: Agent agent_tools: list[Tool | CreateToolRequest] - arguments: dict[str, Any] | RemoteObject + arguments: dict[str, Any] # Not used at the moment user: User | None = None session: Session | None = None -class StepContext(BaseRemoteModel): - execution_input: ExecutionInput | RemoteObject - inputs: list[Any] | RemoteObject +class StepContext(BaseModel): + execution_input: ExecutionInput + inputs: list[Any] cursor: TransitionTarget @computed_field @@ -242,17 +239,9 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]: return dump | execution_input - async def prepare_for_step( - self, *args, include_remote: bool = True, **kwargs - ) -> dict[str, Any]: + async def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]: current_input = self.current_input inputs = self.inputs - if activity.in_activity() and include_remote: - await self.load_all() - inputs = await asyncio.gather( - *[load_from_blob_store_if_remote(input) for input in inputs] - ) - current_input = await load_from_blob_store_if_remote(current_input) # Merge execution inputs into the dump dict dump = self.model_dump(*args, **kwargs) diff --git a/agents-api/agents_api/common/storage_handler.py b/agents-api/agents_api/common/storage_handler.py deleted file mode 100644 index 42beef270..000000000 --- a/agents-api/agents_api/common/storage_handler.py +++ /dev/null @@ -1,226 +0,0 @@ -import asyncio -import sys -from datetime import timedelta -from functools import wraps -from typing import Any, Callable - -from pydantic import BaseModel -from temporalio import workflow - -from ..activities.sync_items_remote import load_inputs_remote -from ..clients import async_s3 -from ..common.protocol.remote import BaseRemoteModel, RemoteObject -from ..common.retry_policies import DEFAULT_RETRY_POLICY -from ..env import ( - blob_store_cutoff_kb, - debug, - temporal_heartbeat_timeout, - temporal_schedule_to_close_timeout, - testing, - use_blob_store_for_temporal, -) -from ..worker.codec import deserialize, serialize - - -async def store_in_blob_store_if_large(x: Any) -> RemoteObject | Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - serialized = serialize(x) - data_size = sys.getsizeof(serialized) - - if data_size > blob_store_cutoff_kb * 1024: - key = await async_s3.add_object_with_hash(serialized) - return RemoteObject(key=key) - - return x - - -async def load_from_blob_store_if_remote(x: Any | RemoteObject) -> Any: - if not use_blob_store_for_temporal: - return x - - await async_s3.setup() - - if isinstance(x, RemoteObject): - fetched = await async_s3.get_object(x.key) - return deserialize(fetched) - - elif isinstance(x, dict) and set(x.keys()) == {"bucket", "key"}: - fetched = await async_s3.get_object(x["key"]) - return deserialize(fetched) - - return x - - -# Decorator that automatically does two things: -# 1. store in blob store if the output of a function is large -# 2. load from blob store if the input is a RemoteObject - - -def auto_blob_store(f: Callable | None = None, *, deep: bool = False) -> Callable: - def auto_blob_store_decorator(f: Callable) -> Callable: - async def load_args( - args: list | tuple, kwargs: dict[str, Any] - ) -> tuple[list | tuple, dict[str, Any]]: - new_args = await asyncio.gather( - *[load_from_blob_store_if_remote(arg) for arg in args] - ) - kwargs_keys, kwargs_values = list(zip(*kwargs.items())) or ([], []) - new_kwargs = await asyncio.gather( - *[load_from_blob_store_if_remote(v) for v in kwargs_values] - ) - new_kwargs = dict(zip(kwargs_keys, new_kwargs)) - - if deep: - args = new_args - kwargs = new_kwargs - - new_args = [] - - for arg in args: - if isinstance(arg, list): - new_args.append( - await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in arg] - ) - ) - elif isinstance(arg, dict): - keys, values = list(zip(*arg.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_args.append(dict(zip(keys, values))) - - elif isinstance(arg, BaseRemoteModel): - new_args.append(await arg.unload_all()) - - elif isinstance(arg, BaseModel): - for field in arg.model_fields.keys(): - if isinstance(getattr(arg, field), RemoteObject): - setattr( - arg, - field, - await load_from_blob_store_if_remote( - getattr(arg, field) - ), - ) - elif isinstance(getattr(arg, field), list): - setattr( - arg, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(arg, field) - ] - ), - ) - elif isinstance(getattr(arg, field), BaseRemoteModel): - setattr( - arg, - field, - await getattr(arg, field).unload_all(), - ) - - new_args.append(arg) - - else: - new_args.append(arg) - - new_kwargs = {} - - for k, v in kwargs.items(): - if isinstance(v, list): - new_kwargs[k] = await asyncio.gather( - *[load_from_blob_store_if_remote(item) for item in v] - ) - - elif isinstance(v, dict): - keys, values = list(zip(*v.items())) or ([], []) - values = await asyncio.gather( - *[load_from_blob_store_if_remote(value) for value in values] - ) - new_kwargs[k] = dict(zip(keys, values)) - - elif isinstance(v, BaseRemoteModel): - new_kwargs[k] = await v.unload_all() - - elif isinstance(v, BaseModel): - for field in v.model_fields.keys(): - if isinstance(getattr(v, field), RemoteObject): - setattr( - v, - field, - await load_from_blob_store_if_remote( - getattr(v, field) - ), - ) - elif isinstance(getattr(v, field), list): - setattr( - v, - field, - await asyncio.gather( - *[ - load_from_blob_store_if_remote(item) - for item in getattr(v, field) - ] - ), - ) - elif isinstance(getattr(v, field), BaseRemoteModel): - setattr( - v, - field, - await getattr(v, field).unload_all(), - ) - new_kwargs[k] = v - - else: - new_kwargs[k] = v - - return new_args, new_kwargs - - async def unload_return_value(x: Any | BaseRemoteModel) -> Any: - if isinstance(x, BaseRemoteModel): - await x.unload_all() - - return await store_in_blob_store_if_large(x) - - @wraps(f) - async def async_wrapper(*args, **kwargs) -> Any: - new_args, new_kwargs = await load_args(args, kwargs) - output = await f(*new_args, **new_kwargs) - - return await unload_return_value(output) - - return async_wrapper if use_blob_store_for_temporal else f - - return auto_blob_store_decorator(f) if f else auto_blob_store_decorator - - -def auto_blob_store_workflow(f: Callable) -> Callable: - @wraps(f) - async def wrapper(*args, **kwargs) -> Any: - keys = kwargs.keys() - values = [kwargs[k] for k in keys] - - loaded = await workflow.execute_activity( - load_inputs_remote, - args=[[*args, *values]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - - loaded_args = loaded[: len(args)] - loaded_kwargs = dict(zip(keys, loaded[len(args) :])) - - result = await f(*loaded_args, **loaded_kwargs) - - return result - - return wrapper if use_blob_store_for_temporal else f diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py index 8b9fd4dae..7baa24653 100644 --- a/agents-api/agents_api/env.py +++ b/agents-api/agents_api/env.py @@ -36,8 +36,8 @@ # Blob Store # ---------- -use_blob_store_for_temporal: bool = ( - env.bool("USE_BLOB_STORE_FOR_TEMPORAL", default=False) if not testing else False +use_blob_store_for_temporal: bool = testing or env.bool( + "USE_BLOB_STORE_FOR_TEMPORAL", default=False ) blob_store_bucket: str = env.str("BLOB_STORE_BUCKET", default="agents-api") diff --git a/agents-api/agents_api/routers/healthz/check_health.py b/agents-api/agents_api/routers/healthz/check_health.py new file mode 100644 index 000000000..5a466ba39 --- /dev/null +++ b/agents-api/agents_api/routers/healthz/check_health.py @@ -0,0 +1,19 @@ +import logging +from uuid import UUID + +from ...models.agent.list_agents import list_agents as list_agents_query +from .router import router + + +@router.get("/healthz", tags=["healthz"]) +async def check_health() -> dict: + try: + # Check if the database is reachable + list_agents_query( + developer_id=UUID("00000000-0000-0000-0000-000000000000"), + ) + except Exception as e: + logging.error("An error occurred while checking health: %s", str(e)) + return {"status": "error", "message": "An internal error has occurred."} + + return {"status": "ok"} diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 6ea9239df..a76c13975 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -15,7 +15,7 @@ from ...activities.excecute_api_call import execute_api_call from ...activities.execute_integration import execute_integration from ...activities.execute_system import execute_system - from ...activities.sync_items_remote import load_inputs_remote, save_inputs_remote + from ...activities.sync_items_remote import save_inputs_remote from ...autogen.openapi_model import ( ApiCallDef, BaseIntegrationDef, @@ -214,16 +214,6 @@ async def run( # 3. Then, based on the outcome and step type, decide what to do next workflow.logger.info(f"Processing outcome for step {context.cursor.step}") - [outcome] = await workflow.execute_activity( - load_inputs_remote, - args=[[outcome]], - schedule_to_close_timeout=timedelta( - seconds=60 if debug or testing else temporal_schedule_to_close_timeout - ), - retry_policy=DEFAULT_RETRY_POLICY, - heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout), - ) - # Init state state = None diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 1d68322f5..b2df640a7 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -19,11 +19,9 @@ ExecutionInput, StepContext, ) - from ...common.storage_handler import auto_blob_store_workflow from ...env import task_max_parallelism, temporal_heartbeat_timeout -@auto_blob_store_workflow async def continue_as_child( execution_input: ExecutionInput, start: TransitionTarget, @@ -50,7 +48,6 @@ async def continue_as_child( ) -@auto_blob_store_workflow async def execute_switch_branch( *, context: StepContext, @@ -84,7 +81,6 @@ async def execute_switch_branch( ) -@auto_blob_store_workflow async def execute_if_else_branch( *, context: StepContext, @@ -123,7 +119,6 @@ async def execute_if_else_branch( ) -@auto_blob_store_workflow async def execute_foreach_step( *, context: StepContext, @@ -161,7 +156,6 @@ async def execute_foreach_step( return results -@auto_blob_store_workflow async def execute_map_reduce_step( *, context: StepContext, @@ -209,7 +203,6 @@ async def execute_map_reduce_step( return result -@auto_blob_store_workflow async def execute_map_reduce_step_parallel( *, context: StepContext, From ca5f4e24a2cedcab3d3bad10b70996b3edd54a27 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Thu, 19 Dec 2024 19:50:21 +0530 Subject: [PATCH 29/29] fix(agents-api): Minor fixes Signed-off-by: Diwank Singh Tomer --- agents-api/agents_api/activities/utils.py | 1 + agents-api/agents_api/queries/sessions/create_session.py | 2 -- agents-api/tests/fixtures.py | 1 - agents-api/tests/test_session_queries.py | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/agents-api/agents_api/activities/utils.py b/agents-api/agents_api/activities/utils.py index cedc01695..d9ad1840c 100644 --- a/agents-api/agents_api/activities/utils.py +++ b/agents-api/agents_api/activities/utils.py @@ -304,6 +304,7 @@ def get_handler(system: SystemDef) -> Callable: from ..models.docs.delete_doc import delete_doc as delete_doc_query from ..models.docs.list_docs import list_docs as list_docs_query from ..models.session.create_session import create_session as create_session_query + from ..models.session.delete_session import delete_session as delete_session_query from ..models.session.get_session import get_session as get_session_query from ..models.session.list_sessions import list_sessions as list_sessions_query from ..models.session.update_session import update_session as update_session_query diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 63fbdc940..058462cf8 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -8,10 +8,8 @@ from ...autogen.openapi_model import ( CreateSessionRequest, - ResourceCreatedResponse, Session, ) -from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 49c2e7094..e1d286c9c 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,6 +1,5 @@ import random import string -import time from uuid import UUID from fastapi.testclient import TestClient diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 5f2190e2b..7926a391f 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -10,7 +10,6 @@ CreateOrUpdateSessionRequest, CreateSessionRequest, PatchSessionRequest, - ResourceCreatedResponse, ResourceDeletedResponse, ResourceUpdatedResponse, Session,