Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Graceful Shutdown in Deriver #86

Merged
merged 1 commit into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/deriver/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@

if __name__ == "__main__":
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
asyncio.run(main())
try:
asyncio.run(main())
except KeyboardInterrupt:
print("Shutdown initiated via KeyboardInterrupt")
335 changes: 197 additions & 138 deletions src/deriver/queue.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import asyncio
import os
from collections.abc import Sequence
import signal
from datetime import datetime, timedelta
from typing import Any, Optional

import sentry_sdk
from dotenv import load_dotenv
from rich import print as rprint
from sentry_sdk.integrations.asyncio import AsyncioIntegration
from sqlalchemy import delete, insert, select, update
from sqlalchemy.exc import IntegrityError
Expand All @@ -20,153 +18,214 @@
load_dotenv()


async def get_next_message_for_session(
db: AsyncSession, session_id: int
) -> Optional[models.QueueItem]:
result = await db.execute(
select(models.QueueItem)
.where(models.QueueItem.session_id == session_id)
.where(models.QueueItem.processed == False)
.order_by(models.QueueItem.id)
.with_for_update(skip_locked=True)
.limit(1)
)
return result.scalar_one_or_none()
class QueueManager:
def __init__(self):
self.shutdown_event = asyncio.Event()
self.active_tasks: set[asyncio.Task] = set()
self.owned_sessions: set[int] = set()
self.queue_empty_flag = asyncio.Event()

# Initialize from environment
self.workers = int(os.getenv("DERIVER_WORKERS", 1))
self.semaphore = asyncio.Semaphore(self.workers)

# Initialize Sentry if enabled
if os.getenv("SENTRY_ENABLED", "False").lower() == "true":
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
enable_tracing=True,
traces_sample_rate=0.1,
profiles_sample_rate=0.1,
integrations=[AsyncioIntegration()],
)

def add_task(self, task: asyncio.Task):
"""Track a new task"""
self.active_tasks.add(task)
task.add_done_callback(self.active_tasks.discard)

def track_session(self, session_id: int):
"""Track a new session owned by this process"""
self.owned_sessions.add(session_id)

def untrack_session(self, session_id: int):
"""Remove a session from tracking"""
self.owned_sessions.discard(session_id)

async def initialize(self):
"""Setup signal handlers and start the main polling loop"""
loop = asyncio.get_running_loop()
signals = (signal.SIGTERM, signal.SIGINT)
for sig in signals:
loop.add_signal_handler(
sig, lambda s=sig: asyncio.create_task(self.shutdown(s))
)

@sentry_sdk.trace
async def process_session_messages(session_id: int):
async with SessionLocal() as db:
try:
while True:
message = await get_next_message_for_session(db, session_id)
if not message:
break
try:
await process_item(db, payload=message.payload)
except Exception as e:
print(e)
sentry_sdk.capture_exception(e)
finally:
# Prevent malformed messages from stalling a queue indefinitely
message.processed = True
await db.commit()

# Update last_updated to show this session is still being processed
await self.polling_loop()
finally:
await self.cleanup()

async def shutdown(self, sig: signal.Signals):
"""Handle graceful shutdown"""
print(f"Received exit signal {sig.name}...")
self.shutdown_event.set()

if self.active_tasks:
print(f"Waiting for {len(self.active_tasks)} active tasks to complete...")
await asyncio.gather(*self.active_tasks, return_exceptions=True)

async def cleanup(self):
"""Clean up owned sessions"""
if self.owned_sessions:
print(f"Cleaning up {len(self.owned_sessions)} owned sessions...")
async with SessionLocal() as db:
await db.execute(
update(models.ActiveQueueSession)
.where(models.ActiveQueueSession.session_id == session_id)
.values(last_updated=func.now())
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.session_id.in_(self.owned_sessions)
)
)
await db.commit()
finally:
# Remove session from active_sessions when done
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.session_id == session_id
)
)
await db.commit()


@sentry_sdk.trace
async def get_available_sessions(db: AsyncSession, limit: int) -> Sequence[Any]:
# First, clean up stale sessions (e.g., older than 5 minutes)
five_minutes_ago = datetime.utcnow() - timedelta(minutes=5)
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.last_updated < five_minutes_ago
##########################
# Polling and Scheduling #
##########################

async def get_available_sessions(self, db: AsyncSession):
"""Get available sessions that aren't being processed"""
# Clean up stale sessions
five_minutes_ago = datetime.utcnow() - timedelta(minutes=5)
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.last_updated < five_minutes_ago
)
)
)

# Then get available sessions
result = await db.execute(
select(models.QueueItem.session_id)
.outerjoin(
models.ActiveQueueSession,
models.QueueItem.session_id == models.ActiveQueueSession.session_id,

# Get available sessions
result = await db.execute(
select(models.QueueItem.session_id)
.outerjoin(
models.ActiveQueueSession,
models.QueueItem.session_id == models.ActiveQueueSession.session_id,
)
.where(models.QueueItem.processed == False)
.where(
models.ActiveQueueSession.session_id == None
) # Only sessions not in active_sessions
.group_by(models.QueueItem.session_id)
.limit(1)
)
.where(models.QueueItem.processed == False)
.where(
models.ActiveQueueSession.session_id == None
) # Only sessions not in active_sessions
.group_by(models.QueueItem.session_id)
.limit(limit)
)
return result.scalars().all()


@sentry_sdk.trace
async def schedule_session(
semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event
):
async with (
semaphore,
SessionLocal() as db,
):
with sentry_sdk.start_transaction(
op="deriver_schedule_session", name="Schedule Deriver Session"
):
try:
# available_slots = semaphore._value
# print(available_slots)
new_sessions = await get_available_sessions(db, 1)

if new_sessions:
for session_id in new_sessions:
return result.scalars().all()

async def polling_loop(self):
"""Main polling loop to find and process new sessions"""
try:
while not self.shutdown_event.is_set():
if self.queue_empty_flag.is_set():
await asyncio.sleep(1)
self.queue_empty_flag.clear()
continue

# Chec if we have capacity before querying
if self.semaphore.locked():
await asyncio.sleep(1) # Wait before trying again
continue

async with SessionLocal() as db:
try:
new_sessions = await self.get_available_sessions(db)

if new_sessions and not self.shutdown_event.is_set():
for session_id in new_sessions:
try:
# Try to claim the session
await db.execute(
insert(models.ActiveQueueSession).values(
session_id=session_id
)
)
await db.commit()

# Track this session
self.track_session(session_id)

# Create a new task for processing this session
if not self.shutdown_event.is_set():
task = asyncio.create_task(
self.process_session(session_id)
)
self.add_task(task)
except IntegrityError:
await db.rollback()
else:
self.queue_empty_flag.set()
await asyncio.sleep(1)
except Exception as e:
print(f"Error in polling loop: {e}")
sentry_sdk.capture_exception(e)
await db.rollback()
await asyncio.sleep(1)
finally:
print("Polling loop stopped")

######################
# Queue Worker Logic #
######################

@sentry_sdk.trace
async def process_session(self, session_id: int):
"""Process all messages for a session"""
async with self.semaphore: # Hold the semaphore for the entire session duration
async with SessionLocal() as db:
try:
while not self.shutdown_event.is_set():
message = await self.get_next_message(db, session_id)
if not message:
break
try:
# Try to insert the session into active_sessions
await db.execute(
insert(models.ActiveQueueSession).values(
session_id=session_id
)
)
await process_item(db, payload=message.payload)
except Exception as e:
print(e)
sentry_sdk.capture_exception(e)
finally:
# Prevent malformed messages from stalling queue indefinitely
message.processed = True
await db.commit()

# If successful, create a task for this session
await process_session_messages(session_id)
except IntegrityError:
# If the session is already in active_sessions, skip it
await db.rollback()

else:
# No items to process, set the queue_empty_flag
queue_empty_flag.set()
except Exception as e:
rprint("==========")
rprint("Exception")
rprint(e)
rprint("==========")
await db.rollback()


async def polling_loop(semaphore: asyncio.Semaphore, queue_empty_flag: asyncio.Event):
while True:
if queue_empty_flag.is_set():
await asyncio.sleep(1) # Sleep briefly if the queue is empty
queue_empty_flag.clear() # Reset the flag
continue
if semaphore.locked():
await asyncio.sleep(1) # Sleep briefly if the semaphore is fully locked
continue
# Create a task instead of awaiting
asyncio.create_task(schedule_session(semaphore, queue_empty_flag))
await asyncio.sleep(0) # Give other tasks a chance to run
if self.shutdown_event.is_set():
break

# Update last_updated timestamp to showthis session is still being processed
await db.execute(
update(models.ActiveQueueSession)
.where(models.ActiveQueueSession.session_id == session_id)
.values(last_updated=func.now())
)
await db.commit()
finally:
# Remove session from active_sessions when done
await db.execute(
delete(models.ActiveQueueSession).where(
models.ActiveQueueSession.session_id == session_id
)
)
await db.commit()
self.untrack_session(session_id)

@sentry_sdk.trace
async def get_next_message(self, db: AsyncSession, session_id: int):
"""Get the next unprocessed message for a session"""
result = await db.execute(
select(models.QueueItem)
.where(models.QueueItem.session_id == session_id)
.where(models.QueueItem.processed == False)
.order_by(models.QueueItem.id)
.with_for_update(skip_locked=True)
.limit(1)
)
return result.scalar_one_or_none()


async def main():
SENTRY_ENABLED = os.getenv("SENTRY_ENABLED", "False").lower() == "true"
if SENTRY_ENABLED:
sentry_sdk.init(
dsn=os.getenv("SENTRY_DSN"),
enable_tracing=True,
traces_sample_rate=0.4,
profiles_sample_rate=0.4,
integrations=[
AsyncioIntegration(),
],
)
workers = int(os.getenv("DERIVER_WORKERS", 1)) + 1
semaphore = asyncio.Semaphore(workers) # Limit to 5 concurrent dequeuing operations
queue_empty_flag = asyncio.Event() # Event to signal when the queue is empty
await polling_loop(semaphore, queue_empty_flag)
manager = QueueManager()
await manager.initialize()
Loading