diff --git a/src/deriver/__main__.py b/src/deriver/__main__.py index cc3e7c6..165d01c 100644 --- a/src/deriver/__main__.py +++ b/src/deriver/__main__.py @@ -6,4 +6,7 @@ if __name__ == "__main__": asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - asyncio.run(main()) + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("Shutdown initiated via KeyboardInterrupt") diff --git a/src/deriver/queue.py b/src/deriver/queue.py index 3bf585f..dbef817 100644 --- a/src/deriver/queue.py +++ b/src/deriver/queue.py @@ -1,12 +1,10 @@ import asyncio import os -from collections.abc import Sequence +import signal from datetime import datetime, timedelta -from typing import Any, Optional import sentry_sdk from dotenv import load_dotenv -from rich import print as rprint from sentry_sdk.integrations.asyncio import AsyncioIntegration from sqlalchemy import delete, insert, select, update from sqlalchemy.exc import IntegrityError @@ -20,153 +18,214 @@ load_dotenv() -async def get_next_message_for_session( - db: AsyncSession, session_id: int -) -> Optional[models.QueueItem]: - result = await db.execute( - select(models.QueueItem) - .where(models.QueueItem.session_id == session_id) - .where(models.QueueItem.processed == False) - .order_by(models.QueueItem.id) - .with_for_update(skip_locked=True) - .limit(1) - ) - return result.scalar_one_or_none() +class QueueManager: + def __init__(self): + self.shutdown_event = asyncio.Event() + self.active_tasks: set[asyncio.Task] = set() + self.owned_sessions: set[int] = set() + self.queue_empty_flag = asyncio.Event() + + # Initialize from environment + self.workers = int(os.getenv("DERIVER_WORKERS", 1)) + self.semaphore = asyncio.Semaphore(self.workers) + + # Initialize Sentry if enabled + if os.getenv("SENTRY_ENABLED", "False").lower() == "true": + sentry_sdk.init( + dsn=os.getenv("SENTRY_DSN"), + enable_tracing=True, + traces_sample_rate=0.1, + profiles_sample_rate=0.1, + integrations=[AsyncioIntegration()], + ) + def add_task(self, task: asyncio.Task): + """Track a new task""" + self.active_tasks.add(task) + task.add_done_callback(self.active_tasks.discard) + + def track_session(self, session_id: int): + """Track a new session owned by this process""" + self.owned_sessions.add(session_id) + + def untrack_session(self, session_id: int): + """Remove a session from tracking""" + self.owned_sessions.discard(session_id) + + async def initialize(self): + """Setup signal handlers and start the main polling loop""" + loop = asyncio.get_running_loop() + signals = (signal.SIGTERM, signal.SIGINT) + for sig in signals: + loop.add_signal_handler( + sig, lambda s=sig: asyncio.create_task(self.shutdown(s)) + ) -@sentry_sdk.trace -async def process_session_messages(session_id: int): - async with SessionLocal() as db: try: - while True: - message = await get_next_message_for_session(db, session_id) - if not message: - break - try: - await process_item(db, payload=message.payload) - except Exception as e: - print(e) - sentry_sdk.capture_exception(e) - finally: - # Prevent malformed messages from stalling a queue indefinitely - message.processed = True - await db.commit() - - # Update last_updated to show this session is still being processed + await self.polling_loop() + finally: + await self.cleanup() + + async def shutdown(self, sig: signal.Signals): + """Handle graceful shutdown""" + print(f"Received exit signal {sig.name}...") + self.shutdown_event.set() + + if self.active_tasks: + print(f"Waiting for {len(self.active_tasks)} active tasks to complete...") + await asyncio.gather(*self.active_tasks, return_exceptions=True) + + async def cleanup(self): + """Clean up owned sessions""" + if self.owned_sessions: + print(f"Cleaning up {len(self.owned_sessions)} owned sessions...") + async with SessionLocal() as db: await db.execute( - update(models.ActiveQueueSession) - .where(models.ActiveQueueSession.session_id == session_id) - .values(last_updated=func.now()) + delete(models.ActiveQueueSession).where( + models.ActiveQueueSession.session_id.in_(self.owned_sessions) + ) ) await db.commit() - finally: - # Remove session from active_sessions when done - await db.execute( - delete(models.ActiveQueueSession).where( - models.ActiveQueueSession.session_id == session_id - ) - ) - await db.commit() - -@sentry_sdk.trace -async def get_available_sessions(db: AsyncSession, limit: int) -> Sequence[Any]: - # First, clean up stale sessions (e.g., older than 5 minutes) - five_minutes_ago = datetime.utcnow() - timedelta(minutes=5) - await db.execute( - delete(models.ActiveQueueSession).where( - models.ActiveQueueSession.last_updated < five_minutes_ago + ########################## + # Polling and Scheduling # + ########################## + + async def get_available_sessions(self, db: AsyncSession): + """Get available sessions that aren't being processed""" + # Clean up stale sessions + five_minutes_ago = datetime.utcnow() - timedelta(minutes=5) + await db.execute( + delete(models.ActiveQueueSession).where( + models.ActiveQueueSession.last_updated < five_minutes_ago + ) ) - ) - - # Then get available sessions - result = await db.execute( - select(models.QueueItem.session_id) - .outerjoin( - models.ActiveQueueSession, - models.QueueItem.session_id == models.ActiveQueueSession.session_id, + + # Get available sessions + result = await db.execute( + select(models.QueueItem.session_id) + .outerjoin( + models.ActiveQueueSession, + models.QueueItem.session_id == models.ActiveQueueSession.session_id, + ) + .where(models.QueueItem.processed == False) + .where( + models.ActiveQueueSession.session_id == None + ) # Only sessions not in active_sessions + .group_by(models.QueueItem.session_id) + .limit(1) ) - .where(models.QueueItem.processed == False) - .where( - models.ActiveQueueSession.session_id == None - ) # Only sessions not in active_sessions - .group_by(models.QueueItem.session_id) - .limit(limit) - ) - return result.scalars().all() - - -@sentry_sdk.trace -async def schedule_session( - semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event -): - async with ( - semaphore, - SessionLocal() as db, - ): - with sentry_sdk.start_transaction( - op="deriver_schedule_session", name="Schedule Deriver Session" - ): - try: - # available_slots = semaphore._value - # print(available_slots) - new_sessions = await get_available_sessions(db, 1) - - if new_sessions: - for session_id in new_sessions: + return result.scalars().all() + + async def polling_loop(self): + """Main polling loop to find and process new sessions""" + try: + while not self.shutdown_event.is_set(): + if self.queue_empty_flag.is_set(): + await asyncio.sleep(1) + self.queue_empty_flag.clear() + continue + + # Chec if we have capacity before querying + if self.semaphore.locked(): + await asyncio.sleep(1) # Wait before trying again + continue + + async with SessionLocal() as db: + try: + new_sessions = await self.get_available_sessions(db) + + if new_sessions and not self.shutdown_event.is_set(): + for session_id in new_sessions: + try: + # Try to claim the session + await db.execute( + insert(models.ActiveQueueSession).values( + session_id=session_id + ) + ) + await db.commit() + + # Track this session + self.track_session(session_id) + + # Create a new task for processing this session + if not self.shutdown_event.is_set(): + task = asyncio.create_task( + self.process_session(session_id) + ) + self.add_task(task) + except IntegrityError: + await db.rollback() + else: + self.queue_empty_flag.set() + await asyncio.sleep(1) + except Exception as e: + print(f"Error in polling loop: {e}") + sentry_sdk.capture_exception(e) + await db.rollback() + await asyncio.sleep(1) + finally: + print("Polling loop stopped") + + ###################### + # Queue Worker Logic # + ###################### + + @sentry_sdk.trace + async def process_session(self, session_id: int): + """Process all messages for a session""" + async with self.semaphore: # Hold the semaphore for the entire session duration + async with SessionLocal() as db: + try: + while not self.shutdown_event.is_set(): + message = await self.get_next_message(db, session_id) + if not message: + break try: - # Try to insert the session into active_sessions - await db.execute( - insert(models.ActiveQueueSession).values( - session_id=session_id - ) - ) + await process_item(db, payload=message.payload) + except Exception as e: + print(e) + sentry_sdk.capture_exception(e) + finally: + # Prevent malformed messages from stalling queue indefinitely + message.processed = True await db.commit() - # If successful, create a task for this session - await process_session_messages(session_id) - except IntegrityError: - # If the session is already in active_sessions, skip it - await db.rollback() - - else: - # No items to process, set the queue_empty_flag - queue_empty_flag.set() - except Exception as e: - rprint("==========") - rprint("Exception") - rprint(e) - rprint("==========") - await db.rollback() - - -async def polling_loop(semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event): - while True: - if queue_empty_flag.is_set(): - await asyncio.sleep(1) # Sleep briefly if the queue is empty - queue_empty_flag.clear() # Reset the flag - continue - if semaphore.locked(): - await asyncio.sleep(1) # Sleep briefly if the semaphore is fully locked - continue - # Create a task instead of awaiting - asyncio.create_task(schedule_session(semaphore, queue_empty_flag)) - await asyncio.sleep(0) # Give other tasks a chance to run + if self.shutdown_event.is_set(): + break + + # Update last_updated timestamp to showthis session is still being processed + await db.execute( + update(models.ActiveQueueSession) + .where(models.ActiveQueueSession.session_id == session_id) + .values(last_updated=func.now()) + ) + await db.commit() + finally: + # Remove session from active_sessions when done + await db.execute( + delete(models.ActiveQueueSession).where( + models.ActiveQueueSession.session_id == session_id + ) + ) + await db.commit() + self.untrack_session(session_id) + + @sentry_sdk.trace + async def get_next_message(self, db: AsyncSession, session_id: int): + """Get the next unprocessed message for a session""" + result = await db.execute( + select(models.QueueItem) + .where(models.QueueItem.session_id == session_id) + .where(models.QueueItem.processed == False) + .order_by(models.QueueItem.id) + .with_for_update(skip_locked=True) + .limit(1) + ) + return result.scalar_one_or_none() async def main(): - SENTRY_ENABLED = os.getenv("SENTRY_ENABLED", "False").lower() == "true" - if SENTRY_ENABLED: - sentry_sdk.init( - dsn=os.getenv("SENTRY_DSN"), - enable_tracing=True, - traces_sample_rate=0.4, - profiles_sample_rate=0.4, - integrations=[ - AsyncioIntegration(), - ], - ) - workers = int(os.getenv("DERIVER_WORKERS", 1)) + 1 - semaphore = asyncio.Semaphore(workers) # Limit to 5 concurrent dequeuing operations - queue_empty_flag = asyncio.Event() # Event to signal when the queue is empty - await polling_loop(semaphore, queue_empty_flag) + manager = QueueManager() + await manager.initialize()