From bbdbb4b369649073fa2334b05e99d34eb44585f4 Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Thu, 19 Dec 2024 12:03:30 +0300 Subject: [PATCH] fix(agents-api): fix sessions and agents queries / tests --- .../queries/entries/create_entries.py | 2 +- .../sessions/create_or_update_session.py | 32 +++++------ .../queries/sessions/create_session.py | 23 +++++--- .../queries/sessions/patch_session.py | 51 +---------------- .../queries/sessions/update_session.py | 56 +++---------------- agents-api/agents_api/queries/utils.py | 17 +++--- agents-api/tests/fixtures.py | 10 ++-- agents-api/tests/test_agent_queries.py | 5 +- agents-api/tests/test_session_queries.py | 49 +++++++--------- 9 files changed, 78 insertions(+), 167 deletions(-) diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index fb61b7c7e..33dcda984 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -173,7 +173,7 @@ async def add_entry_relations( ( session_exists_query, [session_id, developer_id], - "fetch", + "fetchrow", ), ( entry_relation_query, diff --git a/agents-api/agents_api/queries/sessions/create_or_update_session.py b/agents-api/agents_api/queries/sessions/create_or_update_session.py index bc54bf31b..26a353e94 100644 --- a/agents-api/agents_api/queries/sessions/create_or_update_session.py +++ b/agents-api/agents_api/queries/sessions/create_or_update_session.py @@ -61,11 +61,7 @@ participant_type, participant_id ) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; +VALUES ($1, $2, $3, $4); """).sql(pretty=True) @@ -83,16 +79,23 @@ ), } ) -@wrap_in_class(ResourceUpdatedResponse, one=True) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "updated_at": d["updated_at"], + }, +) @increase_counter("create_or_update_session") -@pg_query +@pg_query(return_index=0) @beartype async def create_or_update_session( *, developer_id: UUID, session_id: UUID, data: CreateOrUpdateSessionRequest, -) -> list[tuple[str, list]]: +) -> list[tuple[str, list] | tuple[str, list, str]]: """ Constructs SQL queries to create or update a session and its participant lookups. @@ -139,14 +142,11 @@ async def create_or_update_session( ] # Prepare lookup parameters - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] + lookup_params = [] + for participant_type, participant_id in zip(participant_types, participant_ids): + lookup_params.append([developer_id, session_id, participant_type, participant_id]) return [ - (session_query, session_params), - (lookup_query, lookup_params), + (session_query, session_params, "fetch"), + (lookup_query, lookup_params, "fetchmany"), ] diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index baa3f09d1..91badb281 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -1,12 +1,14 @@ from uuid import UUID +from uuid_extensions import uuid7 import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one -from ...autogen.openapi_model import CreateSessionRequest, Session +from ...autogen.openapi_model import CreateSessionRequest, Session, ResourceCreatedResponse from ...metrics.counters import increase_counter +from ...common.utils.datetime import utcnow from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL queries @@ -63,14 +65,21 @@ ), } ) -@wrap_in_class(Session, transform=lambda d: {**d, "id": d["session_id"]}) +@wrap_in_class( + Session, + one=True, + transform=lambda d: { + **d, + "id": d["session_id"], + }, +) @increase_counter("create_session") -@pg_query +@pg_query(return_index=0) @beartype async def create_session( *, developer_id: UUID, - session_id: UUID, + session_id: UUID | None = None, data: CreateSessionRequest, ) -> list[tuple[str, list] | tuple[str, list, str]]: """ @@ -87,6 +96,7 @@ async def create_session( # Handle participants users = data.users or ([data.user] if data.user else []) agents = data.agents or ([data.agent] if data.agent else []) + session_id = session_id or uuid7() if not agents: raise HTTPException( @@ -123,10 +133,7 @@ async def create_session( for ptype, pid in zip(participant_types, participant_ids): lookup_params.append([developer_id, session_id, ptype, pid]) - print("*" * 100) - print(lookup_params) - print("*" * 100) return [ - (session_query, session_params), + (session_query, session_params, "fetch"), (lookup_query, lookup_params, "fetchmany"), ] diff --git a/agents-api/agents_api/queries/sessions/patch_session.py b/agents-api/agents_api/queries/sessions/patch_session.py index b14b94a8a..60d82468e 100644 --- a/agents-api/agents_api/queries/sessions/patch_session.py +++ b/agents-api/agents_api/queries/sessions/patch_session.py @@ -31,25 +31,6 @@ SELECT * FROM updated_session; """).sql(pretty=True) -lookup_query = parse_one(""" -WITH deleted_lookups AS ( - DELETE FROM session_lookup - WHERE developer_id = $1 AND session_id = $2 -) -INSERT INTO session_lookup ( - developer_id, - session_id, - participant_type, - participant_id -) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; -""").sql(pretty=True) - - @rewrap_exceptions( { asyncpg.ForeignKeyViolationError: partialclass( @@ -64,7 +45,7 @@ ), } ) -@wrap_in_class(ResourceUpdatedResponse, one=True) +@wrap_in_class(ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]},) @increase_counter("patch_session") @pg_query @beartype @@ -85,22 +66,6 @@ async def patch_session( Returns: list[tuple[str, list]]: List of SQL queries and their parameters """ - # Handle participants - users = data.users or ([data.user] if data.user else []) - agents = data.agents or ([data.agent] if data.agent else []) - - if data.agent and data.agents: - raise HTTPException( - status_code=400, - detail="Only one of 'agent' or 'agents' should be provided", - ) - - # Prepare participant arrays for lookup query if participants are provided - participant_types = [] - participant_ids = [] - if users or agents: - participant_types = ["user"] * len(users) + ["agent"] * len(agents) - participant_ids = [str(u) for u in users] + [str(a) for a in agents] # Extract fields from data, using None for unset fields session_params = [ @@ -116,16 +81,4 @@ async def patch_session( data.recall_options or {}, # $10 ] - queries = [(session_query, session_params)] - - # Only add lookup query if participants are provided - if participant_types: - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] - queries.append((lookup_query, lookup_params)) - - return queries + return [(session_query, session_params)] diff --git a/agents-api/agents_api/queries/sessions/update_session.py b/agents-api/agents_api/queries/sessions/update_session.py index 01e21e732..7c58d10e6 100644 --- a/agents-api/agents_api/queries/sessions/update_session.py +++ b/agents-api/agents_api/queries/sessions/update_session.py @@ -27,24 +27,6 @@ RETURNING *; """).sql(pretty=True) -lookup_query = parse_one(""" -WITH deleted_lookups AS ( - DELETE FROM session_lookup - WHERE developer_id = $1 AND session_id = $2 -) -INSERT INTO session_lookup ( - developer_id, - session_id, - participant_type, - participant_id -) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; -""").sql(pretty=True) - @rewrap_exceptions( { @@ -60,7 +42,14 @@ ), } ) -@wrap_in_class(ResourceUpdatedResponse, one=True) +@wrap_in_class( + ResourceUpdatedResponse, + one=True, + transform=lambda d: { + "id": d["session_id"], + "updated_at": d["updated_at"], + }, +) @increase_counter("update_session") @pg_query @beartype @@ -81,26 +70,6 @@ async def update_session( Returns: list[tuple[str, list]]: List of SQL queries and their parameters """ - # Handle participants - users = data.users or ([data.user] if data.user else []) - agents = data.agents or ([data.agent] if data.agent else []) - - if not agents: - raise HTTPException( - status_code=400, - detail="At least one agent must be provided", - ) - - if data.agent and data.agents: - raise HTTPException( - status_code=400, - detail="Only one of 'agent' or 'agents' should be provided", - ) - - # Prepare participant arrays for lookup query - participant_types = ["user"] * len(users) + ["agent"] * len(agents) - participant_ids = [str(u) for u in users] + [str(a) for a in agents] - # Prepare session parameters session_params = [ developer_id, # $1 @@ -115,15 +84,6 @@ async def update_session( data.recall_options or {}, # $10 ] - # Prepare lookup parameters - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] - return [ (session_query, session_params), - (lookup_query, lookup_params), ] diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 73113580d..4126c91dc 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -123,6 +123,7 @@ def pg_query( debug: bool | None = None, only_on_error: bool = False, timeit: bool = False, + return_index: int = -1, ) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]: def pg_query_dec( func: Callable[P, PGQueryArgs | list[PGQueryArgs]], @@ -159,6 +160,8 @@ async def wrapper( async with pool.acquire() as conn: async with conn.transaction(): start = timeit and time.perf_counter() + all_results = [] + for method_name, payload in batch: method = getattr(conn, method_name) @@ -169,11 +172,7 @@ async def wrapper( results: list[Record] = await method( query, *args, timeout=timeout ) - - print("%" * 100) - print(results) - print(*args) - print("%" * 100) + all_results.append(results) if method_name == "fetchrow" and ( len(results) == 0 or results.get("bool") is None @@ -204,9 +203,11 @@ async def wrapper( raise - not only_on_error and debug and pprint(results) - - return results + # Return results from specified index + results_to_return = all_results[return_index] if all_results else [] + not only_on_error and debug and pprint(results_to_return) + + return results_to_return # Set the wrapped function as an attribute of the wrapper, # forwards the __wrapped__ attribute if it exists. diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 9153785a4..49c2e7094 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -96,7 +96,7 @@ def patch_embed_acompletion(): yield embed, acompletion -@fixture(scope="global") +@fixture(scope="test") async def test_agent(dsn=pg_dsn, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -105,18 +105,16 @@ async def test_agent(dsn=pg_dsn, developer=test_developer): data=CreateAgentRequest( model="gpt-4o-mini", name="test agent", - canonical_name=f"test_agent_{str(int(time.time()))}", about="test agent about", metadata={"test": "test"}, ), connection_pool=pool, ) - yield agent - await pool.close() + return agent -@fixture(scope="global") +@fixture(scope="test") async def test_user(dsn=pg_dsn, developer=test_developer): pool = await create_db_pool(dsn=dsn) @@ -153,7 +151,7 @@ async def test_new_developer(dsn=pg_dsn, email=random_email): return developer -@fixture(scope="global") +@fixture(scope="test") async def test_session( dsn=pg_dsn, developer_id=test_developer_id, diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index b6cb7aedc..594047a82 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -41,7 +41,8 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) -@test("query: create agent with instructions sql") + +@test("query: create or update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully created or updated.""" @@ -60,6 +61,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) + @test("query: update agent sql") async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): """Test that an existing agent's information can be successfully updated.""" @@ -81,7 +83,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): assert result is not None assert isinstance(result, ResourceUpdatedResponse) - @test("query: get agent not exists sql") async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that retrieving a non-existent agent raises an exception.""" diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 4e04468bf..ec2e511d4 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -7,13 +7,15 @@ from ward import raises, test from agents_api.autogen.openapi_model import ( + Session, CreateOrUpdateSessionRequest, CreateSessionRequest, + UpdateSessionRequest, PatchSessionRequest, ResourceDeletedResponse, ResourceUpdatedResponse, - Session, - UpdateSessionRequest, + ResourceDeletedResponse, + ResourceCreatedResponse, ) from agents_api.clients.pg import create_db_pool from agents_api.queries.sessions import ( @@ -46,7 +48,6 @@ async def _( data = CreateSessionRequest( users=[user.id], agents=[agent.id], - situation="test session", system_template="test system template", ) result = await create_session( @@ -59,10 +60,6 @@ async def _( assert result is not None assert isinstance(result, Session), f"Result is not a Session, {result}" assert result.id == session_id - assert result.developer_id == developer_id - assert result.situation == "test session" - assert set(result.users) == {user.id} - assert set(result.agents) == {agent.id} @test("query: create or update session sql") @@ -76,7 +73,7 @@ async def _( data = CreateOrUpdateSessionRequest( users=[user.id], agents=[agent.id], - situation="test session", + system_template="test system template", ) result = await create_or_update_session( developer_id=developer_id, @@ -86,12 +83,9 @@ async def _( ) assert result is not None - assert isinstance(result, Session) + assert isinstance(result, ResourceUpdatedResponse) assert result.id == session_id - assert result.developer_id == developer_id - assert result.situation == "test session" - assert set(result.users) == {user.id} - assert set(result.agents) == {agent.id} + assert result.updated_at is not None @test("query: get session exists") @@ -108,7 +102,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): assert result is not None assert isinstance(result, Session) assert result.id == session.id - assert result.developer_id == developer_id @test("query: get session does not exist") @@ -130,7 +123,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): """Test listing sessions with default pagination.""" pool = await create_db_pool(dsn=dsn) - result, _ = await list_sessions( + result = await list_sessions( developer_id=developer_id, limit=10, offset=0, @@ -147,17 +140,18 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): """Test listing sessions with specific filters.""" pool = await create_db_pool(dsn=dsn) - result, _ = await list_sessions( + result = await list_sessions( developer_id=developer_id, limit=10, offset=0, - filters={"situation": "test session"}, connection_pool=pool, ) assert isinstance(result, list) assert len(result) >= 1 - assert all(s.situation == "test session" for s in result) + assert all( + s.situation == session.situation for s in result + ), f"Result is not a list of sessions, {result}, {session.situation}" @test("query: count sessions") @@ -170,20 +164,21 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): connection_pool=pool, ) - assert isinstance(count, int) - assert count >= 1 + assert isinstance(count, dict) + assert count["count"] >= 1 @test("query: update session sql") async def _( - dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent, user=test_user ): """Test that an existing session's information can be successfully updated.""" pool = await create_db_pool(dsn=dsn) data = UpdateSessionRequest( - agents=[agent.id], - situation="updated session", + token_budget=1000, + forward_tool_calls=True, + system_template="updated system template", ) result = await update_session( session_id=session.id, @@ -201,8 +196,7 @@ async def _( session_id=session.id, connection_pool=pool, ) - assert updated_session.situation == "updated session" - assert set(updated_session.agents) == {agent.id} + assert updated_session.forward_tool_calls is True @test("query: patch session sql") @@ -213,8 +207,6 @@ async def _( pool = await create_db_pool(dsn=dsn) data = PatchSessionRequest( - agents=[agent.id], - situation="patched session", metadata={"test": "metadata"}, ) result = await patch_session( @@ -233,8 +225,7 @@ async def _( session_id=session.id, connection_pool=pool, ) - assert patched_session.situation == "patched session" - assert set(patched_session.agents) == {agent.id} + assert patched_session.situation == session.situation assert patched_session.metadata == {"test": "metadata"}