diff --git a/nextcore/common/dispatcher.py b/nextcore/common/dispatcher.py index 2efc96171..0d6660c75 100644 --- a/nextcore/common/dispatcher.py +++ b/nextcore/common/dispatcher.py @@ -21,11 +21,13 @@ from __future__ import annotations -from asyncio import CancelledError, Future, create_task +from asyncio import CancelledError, Future from collections import defaultdict from logging import getLogger from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, cast, overload +from anyio import create_task_group + from .maybe_coro import maybe_coro # Types @@ -441,22 +443,23 @@ async def dispatch(self, event_name: EventNameT, *args: Any) -> None: """ logger.debug("Dispatching event %s", event_name) - # Event handlers - # Tasks are used here as some event handler/check might take a long time. - for handler in self._global_event_handlers: - logger.debug("Dispatching to a global handler") - create_task(self._run_global_event_handler(handler, event_name, *args)) - for handler in self._event_handlers.get(event_name, []): - logger.debug("Dispatching to a local handler") - create_task(self._run_event_handler(handler, event_name, *args)) - - # Wait for handlers - for check, future in self._wait_for_handlers.get(event_name, []): - logger.debug("Dispatching to a wait_for handler") - create_task(self._run_wait_for_handler(check, future, event_name, *args)) - for check, future in self._global_wait_for_handlers: - logger.debug("Dispatching to a global wait_for handler") - create_task(self._run_global_wait_for_handler(check, future, event_name, *args)) + async with create_task_group() as tg: + # Event handlers + # Tasks are used here as some event handler/check might take a long time. 
+ for handler in self._global_event_handlers: + logger.debug("Dispatching to a global handler") + tg.start_soon(self._run_global_event_handler, handler, event_name, *args) + for handler in self._event_handlers.get(event_name, []): + logger.debug("Dispatching to a local handler") + tg.start_soon(self._run_event_handler, handler, event_name, *args) + + # Wait for handlers + for check, future in self._global_wait_for_handlers: + logger.debug("Dispatching to a global wait_for handler") + tg.start_soon(self._run_global_wait_for_handler, check, future, event_name, *args) + for check, future in self._wait_for_handlers.get(event_name, []): + logger.debug("Dispatching to a wait_for handler") + tg.start_soon(self._run_wait_for_handler, check, future, event_name, *args) async def _run_event_handler(self, callback: EventCallback, event_name: EventNameT, *args: Any) -> None: """Run event with exception handlers""" diff --git a/nextcore/gateway/shard.py b/nextcore/gateway/shard.py index d77699b31..6ecc7ae1f 100644 --- a/nextcore/gateway/shard.py +++ b/nextcore/gateway/shard.py @@ -177,6 +177,8 @@ class Shard: "_logger", "_received_heartbeat_ack", "_http_client", + "_receive_task", + "_heartbeat_task", "_heartbeat_sent_at", "_latency", ) @@ -228,6 +230,8 @@ def __init__( self._logger: Logger = getLogger(f"{__name__}.{self.shard_id}") self._received_heartbeat_ack: bool = True self._http_client: HTTPClient = http_client # TODO: Should this be private? + self._receive_task: asyncio.Task[None] | None = None + self._heartbeat_task: asyncio.Task[None] | None = None # Latency self._heartbeat_sent_at: float | None = None @@ -282,7 +286,7 @@ async def connect(self) -> None: self._received_heartbeat_ack = True self._ws = ws # Use the new connection - create_task(self._receive_loop()) + self._receive_task = create_task(self._receive_loop()) # Connection logic is continued in _handle_hello to account for that rate limits are defined there. 
@@ -334,6 +338,17 @@ async def close(self, *, cleanup: bool = True) -> None: await self._ws.close(code=999) self._ws = None # Clear it to save some ram self._send_rate_limit = None # No longer applies + + # safely stop running tasks + + if self._receive_task is not None: + self._receive_task.cancel() + self._receive_task = None + + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + self._heartbeat_task = None + self.connected.clear() @property @@ -538,7 +553,7 @@ async def _handle_hello(self, data: HelloEvent) -> None: heartbeat_interval = data["d"]["heartbeat_interval"] / 1000 # Convert from ms to seconds loop = get_running_loop() - loop.create_task(self._heartbeat_loop(heartbeat_interval)) + self._heartbeat_task = loop.create_task(self._heartbeat_loop(heartbeat_interval)) # Create a rate limiter times, per = self._GATEWAY_SEND_RATE_LIMITS diff --git a/nextcore/gateway/shard_manager.py b/nextcore/gateway/shard_manager.py index e8adec905..98e2904f0 100644 --- a/nextcore/gateway/shard_manager.py +++ b/nextcore/gateway/shard_manager.py @@ -21,12 +21,13 @@ from __future__ import annotations -from asyncio import CancelledError, gather, get_running_loop +from asyncio import CancelledError, Task, gather, get_running_loop from collections import defaultdict from logging import getLogger from typing import TYPE_CHECKING from aiohttp import ClientConnectionError +from anyio import create_task_group from ..common import Dispatcher, TimesPer from ..http import Route @@ -188,17 +189,20 @@ async def connect(self) -> None: else: shard_ids = self.shard_ids - for shard_id in shard_ids: - shard = self._spawn_shard(shard_id, self._active_shard_count) + async with create_task_group() as tg: + for shard_id in shard_ids: + shard = self._spawn_shard(shard_id, self._active_shard_count) + # Here we lazy connect the shard. This gives us a bit more speed when connecting large sets of shards. 
+ tg.start_soon(shard.connect) - # Register event listeners - shard.raw_dispatcher.add_listener(self._on_raw_shard_receive) - shard.event_dispatcher.add_listener(self._on_shard_dispatch) - shard.dispatcher.add_listener(self._on_shard_critical, "critical") + # Register event listeners + shard.raw_dispatcher.add_listener(self._on_raw_shard_receive) + shard.event_dispatcher.add_listener(self._on_shard_dispatch) + shard.dispatcher.add_listener(self._on_shard_critical, "critical") - logger.info("Added shard event listeners") + logger.info("Added shard event listeners") - self.active_shards.append(shard) + self.active_shards.append(shard) def _spawn_shard(self, shard_id: int, shard_count: int) -> Shard: assert self.max_concurrency is not None, "max_concurrency is not set. This is set in connect" @@ -214,10 +218,6 @@ def _spawn_shard(self, shard_id: int, shard_count: int) -> Shard: presence=self.presence, ) - # Here we lazy connect the shard. This gives us a bit more speed when connecting large sets of shards. - loop = get_running_loop() - loop.create_task(shard.connect()) - return shard async def rescale_shards(self, shard_count: int, shard_ids: list[int] | None = None) -> None: diff --git a/pyproject.toml b/pyproject.toml index b30d5e3ea..bfe58dc14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.8" aiohttp = ">=3.6.0,<4.0.0" +anyio = "^3.7.0" frozendict = "^2.3.0" types-frozendict = "^2.0.6" # Could we extend the version requirement typing-extensions = "^4.1.1" # Same as above