From 967c9972d7b26e42d850eace4b7ee8d5f1f64fff Mon Sep 17 00:00:00 2001 From: Vineeth Voruganti <13438633+VVoruganti@users.noreply.github.com> Date: Thu, 14 Nov 2024 21:45:08 -0500 Subject: [PATCH] fix(deriver) Add more detailed profiling with sentry --- src/deriver/consumer.py | 3 ++ src/deriver/queue.py | 85 ++++++++++++++++++---------------- src/deriver/voe.py | 3 ++ tests/routes/test_documents.py | 8 ++-- uv.lock | 2 +- 5 files changed, 56 insertions(+), 45 deletions(-) diff --git a/src/deriver/consumer.py b/src/deriver/consumer.py index 5d76222..128f143 100644 --- a/src/deriver/consumer.py +++ b/src/deriver/consumer.py @@ -1,6 +1,7 @@ import logging import re +import sentry_sdk from rich.console import Console from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -47,6 +48,7 @@ async def process_item(db: AsyncSession, payload: dict): return +@sentry_sdk.trace async def process_ai_message( content: str, app_id: str, @@ -102,6 +104,7 @@ async def process_ai_message( console.print(content_lines, style="blue") +@sentry_sdk.trace async def process_user_message( content: str, app_id: str, diff --git a/src/deriver/queue.py b/src/deriver/queue.py index 0f840a4..3bf585f 100644 --- a/src/deriver/queue.py +++ b/src/deriver/queue.py @@ -34,6 +34,7 @@ async def get_next_message_for_session( return result.scalar_one_or_none() +@sentry_sdk.trace async def process_session_messages(session_id: int): async with SessionLocal() as db: try: @@ -41,7 +42,6 @@ async def process_session_messages(session_id: int): message = await get_next_message_for_session(db, session_id) if not message: break - try: await process_item(db, payload=message.payload) except Exception as e: @@ -69,6 +69,7 @@ async def process_session_messages(session_id: int): await db.commit() +@sentry_sdk.trace async def get_available_sessions(db: AsyncSession, limit: int) -> Sequence[Any]: # First, clean up stale sessions (e.g., older than 5 minutes) five_minutes_ago = datetime.utcnow() - timedelta(minutes=5) @@ -95,45 +96,48 @@ async def get_available_sessions(db: AsyncSession, limit: int) -> Sequence[Any]: return result.scalars().all() +@sentry_sdk.trace async def schedule_session( semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event ): - async with semaphore, SessionLocal() as db: - try: - available_slots = semaphore._value - # print(available_slots) - new_sessions = await get_available_sessions(db, available_slots) - - if new_sessions: - tasks = [] - for session_id in new_sessions: - try: - # Try to insert the session into active_sessions - await db.execute( - insert(models.ActiveQueueSession).values( - session_id=session_id + async with ( + semaphore, + SessionLocal() as db, + ): + with sentry_sdk.start_transaction( + op="deriver_schedule_session", name="Schedule Deriver Session" + ): + try: + # available_slots = semaphore._value + # print(available_slots) + new_sessions = await get_available_sessions(db, 1) + + if new_sessions: + for session_id in new_sessions: + try: + # Try to insert the session into active_sessions + await db.execute( + insert(models.ActiveQueueSession).values( + session_id=session_id + ) ) - ) - await db.commit() - - # If successful, create a task for this session - # Pass enable_timing to process_session_messages - asyncio.create_task(process_session_messages(session_id)) - except IntegrityError: - # If the session is already in active_sessions, skip it - await db.rollback() - - if tasks: - await asyncio.gather(*tasks) - else: - # No items to process, set the queue_empty_flag - queue_empty_flag.set() - except Exception as e: - rprint("==========") - rprint("Exception") - rprint(e) - rprint("==========") - await db.rollback() + await db.commit() + + # If successful, create a task for this session + await process_session_messages(session_id) + except IntegrityError: + # If the session is already in active_sessions, skip it + await db.rollback() + + else: + # No items to process, set the queue_empty_flag + queue_empty_flag.set() + except Exception as e: + rprint("==========") + rprint("Exception") + rprint(e) + rprint("==========") + await db.rollback() async def polling_loop(semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event): @@ -145,8 +149,9 @@ async def polling_loop(semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.E if semaphore.locked(): await asyncio.sleep(1) # Sleep briefly if the semaphore is fully locked continue - await schedule_session(semaphore, queue_empty_flag) - # await asyncio.sleep(0) # Yield control to allow tasks to run + # Create a task instead of awaiting + asyncio.create_task(schedule_session(semaphore, queue_empty_flag)) + await asyncio.sleep(0) # Give other tasks a chance to run async def main(): @@ -155,8 +160,8 @@ async def main(): sentry_sdk.init( dsn=os.getenv("SENTRY_DSN"), enable_tracing=True, - traces_sample_rate=1.0, - profiles_sample_rate=1.0, + traces_sample_rate=0.4, + profiles_sample_rate=0.4, integrations=[ AsyncioIntegration(), ], diff --git a/src/deriver/voe.py b/src/deriver/voe.py index d2763e8..20814f6 100644 --- a/src/deriver/voe.py +++ b/src/deriver/voe.py @@ -1,11 +1,13 @@ import os from anthropic import Anthropic +import sentry_sdk # Initialize the Anthropic client anthropic = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) +@sentry_sdk.trace async def tom_inference( chat_history: str, session_id: str, user_representation: str = "None" ) -> str: @@ -65,6 +67,7 @@ async def tom_inference( return message.content[0].text +@sentry_sdk.trace async def user_representation( chat_history: str, session_id: str, diff --git a/tests/routes/test_documents.py b/tests/routes/test_documents.py index 01f2432..8819d0d 100644 --- a/tests/routes/test_documents.py +++ b/tests/routes/test_documents.py @@ -78,7 +78,7 @@ def test_get_documents(client, sample_data): ) assert response.status_code == 200 data = response.json() - assert len(data) == 3 + assert len(data["items"]) == 3 response = client.post( f"/v1/apps/{test_app.public_id}/users/{test_user.public_id}/collections/{collection['id']}/documents/list", @@ -86,9 +86,9 @@ def test_get_documents(client, sample_data): ) assert response.status_code == 200 data = response.json() - assert len(data) == 2 - assert data[0]["metadata"]["test"] == "key" - assert data[1]["metadata"]["test"] == "key" + assert len(data["items"]) == 2 + assert data["items"][0]["metadata"]["test"] == "key" + assert data["items"][1]["metadata"]["test"] == "key" def test_query_documents(client, sample_data): diff --git a/uv.lock b/uv.lock index 40a79ad..17a498d 100644 --- a/uv.lock +++ b/uv.lock @@ -425,7 +425,7 @@ wheels = [ [[package]] name = "honcho" -version = "0.0.13" +version = "0.0.14" source = { virtual = "." } dependencies = [ { name = "anthropic" },