diff --git a/changes/93.fix b/changes/93.fix new file mode 100644 index 00000000..a27c7127 --- /dev/null +++ b/changes/93.fix @@ -0,0 +1 @@ +Apply queueing of individual event types to reduce contention of overlapping event handlers in each manager process diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py index 5e95534e..0b2b2bce 100644 --- a/src/ai/backend/common/events.py +++ b/src/ai/backend/common/events.py @@ -41,6 +41,7 @@ KernelId, SessionId, LogSeverity, + Sentinel, ) __all__ = ( @@ -621,6 +622,10 @@ class EventDispatcher(aobject): subscriber_loop_task: asyncio.Task consumer_taskset: weakref.WeakSet[asyncio.Task] subscriber_taskset: weakref.WeakSet[asyncio.Task] + consumer_queues: defaultdict[str, asyncio.Queue[ + tuple[str, AgentId, tuple] | Sentinel, + ]] + consumer_handlers: dict[str, asyncio.Task] _log_events: bool _consumer_name: str @@ -641,10 +646,18 @@ def __init__( self._consumer_name = secrets.token_urlsafe(16) async def __ainit__(self) -> None: - self.consumer_loop_task = asyncio.create_task(self._consume_loop()) - self.subscriber_loop_task = asyncio.create_task(self._subscribe_loop()) + self.consumer_loop_task = asyncio.create_task( + self._consume_loop(), + name="evdispatcher.consumer_loop", + ) + self.subscriber_loop_task = asyncio.create_task( + self._subscribe_loop(), + name="evdispatcher.subscriber_loop", + ) self.consumer_taskset = weakref.WeakSet() self.subscriber_taskset = weakref.WeakSet() + self.consumer_queues = defaultdict(asyncio.Queue) + self.consumer_handlers = {} async def close(self) -> None: cancelled_tasks = [] @@ -661,6 +674,11 @@ async def close(self) -> None: cancelled_tasks.append(self.consumer_loop_task) cancelled_tasks.append(self.subscriber_loop_task) await asyncio.gather(*cancelled_tasks, return_exceptions=True) + join_tasks = [] + for q in self.consumer_queues.values(): + q.put_nowait(Sentinel.TOKEN) + join_tasks.append(q.join()) + await asyncio.gather(*join_tasks) await self.redis_client.close() def consume( @@ -731,11 +749,15 @@ async def dispatch_consumers( ) -> None: if self._log_events: log.debug('DISPATCH_CONSUMERS(ev:{}, ag:{})', event_name, source) + consumer_tasks = [] for consumer in self.consumers[event_name].copy(): - self.consumer_taskset.add(asyncio.create_task( + consumer_tasks.append(asyncio.create_task( self.handle("CONSUMER", consumer, source, args), )) - await asyncio.sleep(0) + results = await asyncio.gather(*consumer_tasks, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + log.error("unexpected error while processing ev:{}", event_name, exc_info=result) async def dispatch_subscribers( self, @@ -751,6 +773,16 @@ async def dispatch_subscribers( )) await asyncio.sleep(0) + async def _consume_handle(self, event_name) -> None: + while True: + params = await self.consumer_queues[event_name].get() + try: + if params is Sentinel.TOKEN: + break + await self.dispatch_consumers(*params) + finally: + self.consumer_queues[event_name].task_done() + async def _consume_loop(self) -> None: async with aclosing(redis.read_stream_by_group( self.redis_client, @@ -760,11 +792,19 @@ async def _consume_loop(self) -> None: )) as agen: async for msg_id, msg_data in agen: try: - await self.dispatch_consumers( - msg_data[b'name'].decode(), - msg_data[b'source'].decode(), - msgpack.unpackb(msg_data[b'args']), + event_name = msg_data[b'name'].decode() + self.consumer_queues[event_name].put_nowait( + ( + event_name, + msg_data[b'source'].decode(), + msgpack.unpackb(msg_data[b'args']), + ), ) + if event_name not in self.consumer_handlers: + self.consumer_handlers[event_name] = asyncio.create_task( + self._consume_handle(event_name), + name=f"evdispatcher.consume_handler.{event_name}" + ) except asyncio.CancelledError: raise except Exception: