Skip to content

Commit

Permalink
feat(deriver): Deriver now supports SIGINT and SIGTERM signals with s…
Browse files Browse the repository at this point in the history
…hutdown logic
  • Loading branch information
VVoruganti committed Dec 29, 2024
1 parent 079084e commit 766f8f1
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 139 deletions.
5 changes: 4 additions & 1 deletion src/deriver/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@

if __name__ == "__main__":
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
asyncio.run(main())
try:
asyncio.run(main())
except KeyboardInterrupt:
print("Shutdown initiated via KeyboardInterrupt")
335 changes: 197 additions & 138 deletions src/deriver/queue.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import asyncio
import os
from collections.abc import Sequence
import signal
from datetime import datetime, timedelta
from typing import Any, Optional

import sentry_sdk
from dotenv import load_dotenv
from rich import print as rprint
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sqlalchemy import delete, insert, select, update
from sqlalchemy.exc import IntegrityError
Expand All @@ -20,153 +18,214 @@
load_dotenv()


async def get_next_message_for_session(
db: AsyncSession, session_id: int
) -> Optional[models.QueueItem]:
result = await db.execute(
select(models.QueueItem)
.where(models.QueueItem.session_id == session_id)
.where(models.QueueItem.processed == False)
.order_by(models.QueueItem.id)
.with_for_update(skip_locked=True)
.limit(1)
)
return result.scalar_one_or_none()
class QueueManager:
def __init__(self):
self.shutdown_event = asyncio.Event()
self.active_tasks: set[asyncio.Task] = set()
self.owned_sessions: set[int] = set()
self.queue_empty_flag = asyncio.Event()

# Initialize from environment
self.workers = int(os.getenv("DERIVER_WORKERS", 1))
self.semaphore = asyncio.Semaphore(self.workers)

# Initialize Sentry if enabled
if os.getenv("SENTRY_ENABLED", "False").lower() == "true":
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
enable_tracing=True,
traces_sample_rate=0.1,
profiles_sample_rate=0.1,
integrations=[AsyncioIntegration()],
)

def add_task(self, task: asyncio.Task):
"""Track a new task"""
self.active_tasks.add(task)
task.add_done_callback(self.active_tasks.discard)

def track_session(self, session_id: int):
"""Track a new session owned by this process"""
self.owned_sessions.add(session_id)

def untrack_session(self, session_id: int):
"""Remove a session from tracking"""
self.owned_sessions.discard(session_id)

async def initialize(self):
"""Setup signal handlers and start the main polling loop"""
loop = asyncio.get_running_loop()
signals = (signal.SIGTERM, signal.SIGINT)
for sig in signals:
loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(self.shutdown(s))
)

@sentry_sdk.trace
async def process_session_messages(session_id: int):
async with SessionLocal() as db:
try:
while True:
message = await get_next_message_for_session(db, session_id)
if not message:
break
try:
await process_item(db, payload=message.payload)
except Exception as e:
print(e)
sentry_sdk.capture_exception(e)
finally:
# Prevent malformed messages from stalling a queue indefinitely
message.processed = True
await db.commit()

# Update last_updated to show this session is still being processed
await self.polling_loop()
finally:
await self.cleanup()

async def shutdown(self, sig: signal.Signals):
"""Handle graceful shutdown"""
print(f"Received exit signal {sig.name}...")
self.shutdown_event.set()

if self.active_tasks:
print(f"Waiting for {len(self.active_tasks)} active tasks to complete...")
await asyncio.gather(*self.active_tasks, return_exceptions=True)

async def cleanup(self):
"""Clean up owned sessions"""
if self.owned_sessions:
print(f"Cleaning up {len(self.owned_sessions)} owned sessions...")
async with SessionLocal() as db:
await db.execute(
update(models.ActiveQueueSession)
.where(models.ActiveQueueSession.session_id == session_id)
.values(last_updated=func.now())
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.session_id.in_(self.owned_sessions)
)
)
await db.commit()
finally:
# Remove session from active_sessions when done
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.session_id == session_id
)
)
await db.commit()


@sentry_sdk.trace
async def get_available_sessions(db: AsyncSession, limit: int) -> Sequence[Any]:
# First, clean up stale sessions (e.g., older than 5 minutes)
five_minutes_ago = datetime.utcnow() - timedelta(minutes=5)
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.last_updated < five_minutes_ago
##########################
# Polling and Scheduling #
##########################

async def get_available_sessions(self, db: AsyncSession):
"""Get available sessions that aren't being processed"""
# Clean up stale sessions
five_minutes_ago = datetime.utcnow() - timedelta(minutes=5)
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.last_updated < five_minutes_ago
)
)
)

# Then get available sessions
result = await db.execute(
select(models.QueueItem.session_id)
.outerjoin(
models.ActiveQueueSession,
models.QueueItem.session_id == models.ActiveQueueSession.session_id,

# Get available sessions
result = await db.execute(
select(models.QueueItem.session_id)
.outerjoin(
models.ActiveQueueSession,
models.QueueItem.session_id == models.ActiveQueueSession.session_id,
)
.where(models.QueueItem.processed == False)
.where(
models.ActiveQueueSession.session_id == None
) # Only sessions not in active_sessions
.group_by(models.QueueItem.session_id)
.limit(1)
)
.where(models.QueueItem.processed == False)
.where(
models.ActiveQueueSession.session_id == None
) # Only sessions not in active_sessions
.group_by(models.QueueItem.session_id)
.limit(limit)
)
return result.scalars().all()


@sentry_sdk.trace
async def schedule_session(
semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event
):
async with (
semaphore,
SessionLocal() as db,
):
with sentry_sdk.start_transaction(
op="deriver_schedule_session", name="Schedule Deriver Session"
):
try:
# available_slots = semaphore._value
# print(available_slots)
new_sessions = await get_available_sessions(db, 1)

if new_sessions:
for session_id in new_sessions:
return result.scalars().all()

async def polling_loop(self):
"""Main polling loop to find and process new sessions"""
try:
while not self.shutdown_event.is_set():
if self.queue_empty_flag.is_set():
await asyncio.sleep(1)
self.queue_empty_flag.clear()
continue

# Chec if we have capacity before querying
if self.semaphore.locked():
await asyncio.sleep(1) # Wait before trying again
continue

async with SessionLocal() as db:
try:
new_sessions = await self.get_available_sessions(db)

if new_sessions and not self.shutdown_event.is_set():
for session_id in new_sessions:
try:
# Try to claim the session
await db.execute(
insert(models.ActiveQueueSession).values(
session_id=session_id
)
)
await db.commit()

# Track this session
self.track_session(session_id)

# Create a new task for processing this session
if not self.shutdown_event.is_set():
task = asyncio.create_task(
self.process_session(session_id)
)
self.add_task(task)
except IntegrityError:
await db.rollback()
else:
self.queue_empty_flag.set()
await asyncio.sleep(1)
except Exception as e:
print(f"Error in polling loop: {e}")
sentry_sdk.capture_exception(e)
await db.rollback()
await asyncio.sleep(1)
finally:
print("Polling loop stopped")

######################
# Queue Worker Logic #
######################

@sentry_sdk.trace
async def process_session(self, session_id: int):
"""Process all messages for a session"""
async with self.semaphore: # Hold the semaphore for the entire session duration
async with SessionLocal() as db:
try:
while not self.shutdown_event.is_set():
message = await self.get_next_message(db, session_id)
if not message:
break
try:
# Try to insert the session into active_sessions
await db.execute(
insert(models.ActiveQueueSession).values(
session_id=session_id
)
)
await process_item(db, payload=message.payload)
except Exception as e:
print(e)
sentry_sdk.capture_exception(e)
finally:
# Prevent malformed messages from stalling queue indefinitely
message.processed = True
await db.commit()

# If successful, create a task for this session
await process_session_messages(session_id)
except IntegrityError:
# If the session is already in active_sessions, skip it
await db.rollback()

else:
# No items to process, set the queue_empty_flag
queue_empty_flag.set()
except Exception as e:
rprint("==========")
rprint("Exception")
rprint(e)
rprint("==========")
await db.rollback()


async def polling_loop(semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event):
while True:
if queue_empty_flag.is_set():
await asyncio.sleep(1) # Sleep briefly if the queue is empty
queue_empty_flag.clear() # Reset the flag
continue
if semaphore.locked():
await asyncio.sleep(1) # Sleep briefly if the semaphore is fully locked
continue
# Create a task instead of awaiting
asyncio.create_task(schedule_session(semaphore, queue_empty_flag))
await asyncio.sleep(0) # Give other tasks a chance to run
if self.shutdown_event.is_set():
break

# Update last_updated timestamp to showthis session is still being processed
await db.execute(
update(models.ActiveQueueSession)
.where(models.ActiveQueueSession.session_id == session_id)
.values(last_updated=func.now())
)
await db.commit()
finally:
# Remove session from active_sessions when done
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.session_id == session_id
)
)
await db.commit()
self.untrack_session(session_id)

@sentry_sdk.trace
async def get_next_message(self, db: AsyncSession, session_id: int):
"""Get the next unprocessed message for a session"""
result = await db.execute(
select(models.QueueItem)
.where(models.QueueItem.session_id == session_id)
.where(models.QueueItem.processed == False)
.order_by(models.QueueItem.id)
.with_for_update(skip_locked=True)
.limit(1)
)
return result.scalar_one_or_none()


async def main():
SENTRY_ENABLED = os.getenv("SENTRY_ENABLED", "False").lower() == "true"
if SENTRY_ENABLED:
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
enable_tracing=True,
traces_sample_rate=0.4,
profiles_sample_rate=0.4,
integrations=[
AsyncioIntegration(),
],
)
workers = int(os.getenv("DERIVER_WORKERS", 1)) + 1
semaphore = asyncio.Semaphore(workers) # Limit to 5 concurrent dequeuing operations
queue_empty_flag = asyncio.Event() # Event to signal when the queue is empty
await polling_loop(semaphore, queue_empty_flag)
manager = QueueManager()
await manager.initialize()

0 comments on commit 766f8f1

Please sign in to comment.