From d7fd6ac9e7e58176b356a971c26dc9e51e358aef Mon Sep 17 00:00:00 2001 From: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Date: Sun, 9 Jul 2023 22:49:58 -0700 Subject: [PATCH 1/5] refactor: safely store and close tasks --- nextcore/common/dispatcher.py | 14 +++++++++----- nextcore/gateway/shard.py | 21 +++++++++++++++++++-- nextcore/gateway/shard_manager.py | 14 +++++++++----- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/nextcore/common/dispatcher.py b/nextcore/common/dispatcher.py index 2efc96171..d6a333461 100644 --- a/nextcore/common/dispatcher.py +++ b/nextcore/common/dispatcher.py @@ -21,7 +21,7 @@ from __future__ import annotations -from asyncio import CancelledError, Future, create_task +from asyncio import CancelledError, Future, Task, create_task, gather from collections import defaultdict from logging import getLogger from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, cast, overload @@ -441,22 +441,26 @@ async def dispatch(self, event_name: EventNameT, *args: Any) -> None: """ logger.debug("Dispatching event %s", event_name) + pending_tasks: list[Task[None]] = [] + # Event handlers # Tasks are used here as some event handler/check might take a long time. for handler in self._global_event_handlers: logger.debug("Dispatching to a global handler") - create_task(self._run_global_event_handler(handler, event_name, *args)) + pending_tasks.append(create_task(self._run_global_event_handler(handler, event_name, *args))) for handler in self._event_handlers.get(event_name, []): logger.debug("Dispatching to a local handler") - create_task(self._run_event_handler(handler, event_name, *args)) + pending_tasks.append(create_task(self._run_event_handler(handler, event_name, *args))) # Wait for handlers for check, future in self._wait_for_handlers.get(event_name, []): logger.debug("Dispatching to a wait_for handler") - create_task(self._run_wait_for_handler(check, future, event_name, *args)) + pending_tasks.append(create_task(self._run_wait_for_handler(check, future, event_name, *args))) for check, future in self._global_wait_for_handlers: logger.debug("Dispatching to a global wait_for handler") - create_task(self._run_global_wait_for_handler(check, future, event_name, *args)) + pending_tasks.append(create_task(self._run_global_wait_for_handler(check, future, event_name, *args))) + + await gather(*pending_tasks) async def _run_event_handler(self, callback: EventCallback, event_name: EventNameT, *args: Any) -> None: """Run event with exception handlers""" diff --git a/nextcore/gateway/shard.py b/nextcore/gateway/shard.py index d77699b31..c2ff5d64c 100644 --- a/nextcore/gateway/shard.py +++ b/nextcore/gateway/shard.py @@ -177,6 +177,8 @@ class Shard: "_logger", "_received_heartbeat_ack", "_http_client", + "_receive_task", + "_heartbeat_task", "_heartbeat_sent_at", "_latency", ) @@ -228,6 +230,8 @@ def __init__( self._logger: Logger = getLogger(f"{__name__}.{self.shard_id}") self._received_heartbeat_ack: bool = True self._http_client: HTTPClient = http_client # TODO: Should this be private? 
+ self._receive_task: asyncio.Task[None] | None = None + self._heartbeat_task: asyncio.Task[None] | None = None # Latency self._heartbeat_sent_at: float | None = None @@ -282,7 +286,7 @@ async def connect(self) -> None: self._received_heartbeat_ack = True self._ws = ws # Use the new connection - create_task(self._receive_loop()) + self._receive_task = create_task(self._receive_loop()) # Connection logic is continued in _handle_hello to account for that rate limits are defined there. @@ -334,6 +338,19 @@ async def close(self, *, cleanup: bool = True) -> None: await self._ws.close(code=999) self._ws = None # Clear it to save some ram self._send_rate_limit = None # No longer applies + + # safely stop running tasks + + if self._receive_task is not None: + self._receive_task.cancel() + await self._receive_task + self._receive_task = None + + if self._heartbeat_task is not None: + self._heartbeat_task.cancel() + await self._heartbeat_task + self._heartbeat_task = None + self.connected.clear() @property @@ -538,7 +555,7 @@ async def _handle_hello(self, data: HelloEvent) -> None: heartbeat_interval = data["d"]["heartbeat_interval"] / 1000 # Convert from ms to seconds loop = get_running_loop() - loop.create_task(self._heartbeat_loop(heartbeat_interval)) + self._heartbeat_task = loop.create_task(self._heartbeat_loop(heartbeat_interval)) # Create a rate limiter times, per = self._GATEWAY_SEND_RATE_LIMITS diff --git a/nextcore/gateway/shard_manager.py b/nextcore/gateway/shard_manager.py index e8adec905..800b98d72 100644 --- a/nextcore/gateway/shard_manager.py +++ b/nextcore/gateway/shard_manager.py @@ -21,7 +21,7 @@ from __future__ import annotations -from asyncio import CancelledError, gather, get_running_loop +from asyncio import CancelledError, Task, gather, get_running_loop from collections import defaultdict from logging import getLogger from typing import TYPE_CHECKING @@ -188,8 +188,10 @@ async def connect(self) -> None: else: shard_ids = self.shard_ids + connect_tasks: list[Task[None]] = [] for shard_id in shard_ids: - shard = self._spawn_shard(shard_id, self._active_shard_count) + shard, connect_task = self._spawn_shard(shard_id, self._active_shard_count) + connect_tasks.append(connect_task) # Register event listeners shard.raw_dispatcher.add_listener(self._on_raw_shard_receive) @@ -200,7 +202,9 @@ async def connect(self) -> None: self.active_shards.append(shard) - def _spawn_shard(self, shard_id: int, shard_count: int) -> Shard: + await gather(*connect_tasks) + + def _spawn_shard(self, shard_id: int, shard_count: int) -> tuple[Shard, Task[None]]: assert self.max_concurrency is not None, "max_concurrency is not set. This is set in connect" rate_limiter = self._identify_rate_limits[shard_id % self.max_concurrency] @@ -216,9 +220,9 @@ def _spawn_shard(self, shard_id: int, shard_count: int) -> Shard: # Here we lazy connect the shard. This gives us a bit more speed when connecting large sets of shards. 
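
    # A minimal, standalone sketch of the store-and-cancel pattern introduced
    # above (hypothetical Worker class, not nextcore's Shard): keeping the Task
    # object on the instance prevents it from being garbage collected mid-flight
    # and gives close() something concrete to cancel.
    from __future__ import annotations

    import asyncio


    class Worker:
        def __init__(self) -> None:
            self._task: asyncio.Task[None] | None = None

        async def start(self) -> None:
            # Store the task instead of fire-and-forgetting it.
            self._task = asyncio.create_task(self._loop())

        async def _loop(self) -> None:
            while True:
                await asyncio.sleep(1)

        async def close(self) -> None:
            if self._task is not None:
                self._task.cancel()
                try:
                    # Awaiting lets the task run its cleanup, but the
                    # CancelledError it re-raises must be suppressed or close()
                    # itself fails (patch 3 below avoids the issue by dropping
                    # the await instead).
                    await self._task
                except asyncio.CancelledError:
                    pass
                self._task = None


    async def main() -> None:
        worker = Worker()
        await worker.start()
        await asyncio.sleep(0)
        await worker.close()


    asyncio.run(main())
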
loop = get_running_loop() - loop.create_task(shard.connect()) + task = loop.create_task(shard.connect()) - return shard + return (shard, task) async def rescale_shards(self, shard_count: int, shard_ids: list[int] | None = None) -> None: """Change the shard count without restarting From f8784f98cd1c74f86064124684b1eae937e30b68 Mon Sep 17 00:00:00 2001 From: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Date: Wed, 12 Jul 2023 00:32:40 -0700 Subject: [PATCH 2/5] refactor: use task groups --- nextcore/common/dispatcher.py | 41 +++++++++++++++---------------- nextcore/gateway/shard_manager.py | 32 +++++++++++------------- pyproject.toml | 1 + 3 files changed, 35 insertions(+), 39 deletions(-) diff --git a/nextcore/common/dispatcher.py b/nextcore/common/dispatcher.py index d6a333461..bd7ea775d 100644 --- a/nextcore/common/dispatcher.py +++ b/nextcore/common/dispatcher.py @@ -21,11 +21,13 @@ from __future__ import annotations -from asyncio import CancelledError, Future, Task, create_task, gather +from asyncio import CancelledError, Future from collections import defaultdict from logging import getLogger from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, cast, overload +from anyio import create_task_group + from .maybe_coro import maybe_coro # Types @@ -441,26 +443,23 @@ async def dispatch(self, event_name: EventNameT, *args: Any) -> None: """ logger.debug("Dispatching event %s", event_name) - pending_tasks: list[Task[None]] = [] - - # Event handlers - # Tasks are used here as some event handler/check might take a long time. - for handler in self._global_event_handlers: - logger.debug("Dispatching to a global handler") - pending_tasks.append(create_task(self._run_global_event_handler(handler, event_name, *args))) - for handler in self._event_handlers.get(event_name, []): - logger.debug("Dispatching to a local handler") - pending_tasks.append(create_task(self._run_event_handler(handler, event_name, *args))) - - # Wait for handlers - for check, future in self._wait_for_handlers.get(event_name, []): - logger.debug("Dispatching to a wait_for handler") - pending_tasks.append(create_task(self._run_wait_for_handler(check, future, event_name, *args))) - for check, future in self._global_wait_for_handlers: - logger.debug("Dispatching to a global wait_for handler") - pending_tasks.append(create_task(self._run_global_wait_for_handler(check, future, event_name, *args))) - - await gather(*pending_tasks) + async with create_task_group() as tg: + # Event handlers + # Tasks are used here as some event handler/check might take a long time. 
+ for handler in self._global_event_handlers: + logger.debug("Dispatching to a global handler") + await tg.spawn(self._run_global_event_handler, handler, event_name, *args) + for handler in self._event_handlers.get(event_name, []): + logger.debug("Dispatching to a local handler") + await tg.spawn(self._run_event_handler, handler, event_name, *args) + + # Wait for handlers + for check, future in self._global_wait_for_handlers: + logger.debug("Dispatching to a global wait_for handler") + await tg.spawn(self._run_global_wait_for_handler, check, future, event_name, *args) + for check, future in self._wait_for_handlers.get(event_name, []): + logger.debug("Dispatching to a wait_for handler") + await tg.spawn(self._run_wait_for_handler, check, future, event_name, *args) async def _run_event_handler(self, callback: EventCallback, event_name: EventNameT, *args: Any) -> None: """Run event with exception handlers""" diff --git a/nextcore/gateway/shard_manager.py b/nextcore/gateway/shard_manager.py index 800b98d72..98e2904f0 100644 --- a/nextcore/gateway/shard_manager.py +++ b/nextcore/gateway/shard_manager.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING from aiohttp import ClientConnectionError +from anyio import create_task_group from ..common import Dispatcher, TimesPer from ..http import Route @@ -188,23 +189,22 @@ async def connect(self) -> None: else: shard_ids = self.shard_ids - connect_tasks: list[Task[None]] = [] - for shard_id in shard_ids: - shard, connect_task = self._spawn_shard(shard_id, self._active_shard_count) - connect_tasks.append(connect_task) - - # Register event listeners - shard.raw_dispatcher.add_listener(self._on_raw_shard_receive) - shard.event_dispatcher.add_listener(self._on_shard_dispatch) - shard.dispatcher.add_listener(self._on_shard_critical, "critical") + async with create_task_group() as tg: + for shard_id in shard_ids: + shard = self._spawn_shard(shard_id, self._active_shard_count) + # Here we lazy connect the shard. This gives us a bit more speed when connecting large sets of shards. + await tg.spawn(shard.connect) - logger.info("Added shard event listeners") + # Register event listeners + shard.raw_dispatcher.add_listener(self._on_raw_shard_receive) + shard.event_dispatcher.add_listener(self._on_shard_dispatch) + shard.dispatcher.add_listener(self._on_shard_critical, "critical") - self.active_shards.append(shard) + logger.info("Added shard event listeners") - await gather(*connect_tasks) + self.active_shards.append(shard) - def _spawn_shard(self, shard_id: int, shard_count: int) -> tuple[Shard, Task[None]]: + def _spawn_shard(self, shard_id: int, shard_count: int) -> Shard: assert self.max_concurrency is not None, "max_concurrency is not set. This is set in connect" rate_limiter = self._identify_rate_limits[shard_id % self.max_concurrency] @@ -218,11 +218,7 @@ def _spawn_shard(self, shard_id: int, shard_count: int) -> tuple[Shard, Task[Non presence=self.presence, ) - # Here we lazy connect the shard. This gives us a bit more speed when connecting large sets of shards. 
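
    # A minimal, standalone sketch of the anyio task-group shape adopted above
    # (hypothetical handler, not nextcore's dispatcher): the async with block
    # only exits once every task started inside it has finished, which replaces
    # the explicit gather() from the previous patch.
    import anyio


    async def handler(event_name: str, index: int) -> None:
        # Stand-in for an event handler that might take a while.
        await anyio.sleep(0.1)
        print("handled", event_name, index)


    async def dispatch(event_name: str) -> None:
        async with anyio.create_task_group() as tg:
            for index in range(3):
                # anyio >= 3.0 spells this start_soon(); the tg.spawn() used in
                # this patch is a deprecated alias, which patch 4 below replaces.
                tg.start_soon(handler, event_name, index)


    anyio.run(dispatch, "MESSAGE_CREATE")
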
- loop = get_running_loop() - task = loop.create_task(shard.connect()) - - return (shard, task) + return shard async def rescale_shards(self, shard_count: int, shard_ids: list[int] | None = None) -> None: """Change the shard count without restarting diff --git a/pyproject.toml b/pyproject.toml index b30d5e3ea..8f1b95674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.8" aiohttp = ">=3.6.0,<4.0.0" +anyio = ">=3.7.0,<4.0.0" frozendict = "^2.3.0" types-frozendict = "^2.0.6" # Could we extend the version requirement typing-extensions = "^4.1.1" # Same as above From 373ef50fedc0be7c87eba821854c22c275815374 Mon Sep 17 00:00:00 2001 From: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Date: Wed, 12 Jul 2023 00:33:20 -0700 Subject: [PATCH 3/5] fix(Shard): tasks are not awaited anymore --- nextcore/gateway/shard.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nextcore/gateway/shard.py b/nextcore/gateway/shard.py index c2ff5d64c..6ecc7ae1f 100644 --- a/nextcore/gateway/shard.py +++ b/nextcore/gateway/shard.py @@ -343,12 +343,10 @@ async def close(self, *, cleanup: bool = True) -> None: if self._receive_task is not None: self._receive_task.cancel() - await self._receive_task self._receive_task = None if self._heartbeat_task is not None: self._heartbeat_task.cancel() - await self._heartbeat_task self._heartbeat_task = None self.connected.clear() From 293f0c4edb56daca346d61f8cc39d2c30ca8a1fb Mon Sep 17 00:00:00 2001 From: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Date: Wed, 12 Jul 2023 01:08:48 -0700 Subject: [PATCH 4/5] fix: TaskGroup.spawn is deprecated --- nextcore/common/dispatcher.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nextcore/common/dispatcher.py b/nextcore/common/dispatcher.py index bd7ea775d..0d6660c75 100644 --- a/nextcore/common/dispatcher.py +++ b/nextcore/common/dispatcher.py @@ -448,18 +448,18 @@ async def dispatch(self, event_name: EventNameT, *args: Any) -> None: # Tasks are used here as some event handler/check might take a long time. 
for handler in self._global_event_handlers: logger.debug("Dispatching to a global handler") - await tg.spawn(self._run_global_event_handler, handler, event_name, *args) + tg.start_soon(self._run_global_event_handler, handler, event_name, *args) for handler in self._event_handlers.get(event_name, []): logger.debug("Dispatching to a local handler") - await tg.spawn(self._run_event_handler, handler, event_name, *args) + tg.start_soon(self._run_event_handler, handler, event_name, *args) # Wait for handlers for check, future in self._global_wait_for_handlers: logger.debug("Dispatching to a global wait_for handler") - await tg.spawn(self._run_global_wait_for_handler, check, future, event_name, *args) + tg.start_soon(self._run_global_wait_for_handler, check, future, event_name, *args) for check, future in self._wait_for_handlers.get(event_name, []): logger.debug("Dispatching to a wait_for handler") - await tg.spawn(self._run_wait_for_handler, check, future, event_name, *args) + tg.start_soon(self._run_wait_for_handler, check, future, event_name, *args) async def _run_event_handler(self, callback: EventCallback, event_name: EventNameT, *args: Any) -> None: """Run event with exception handlers""" From 36de37beb8487dbe7db965f9f23c6bd479b94852 Mon Sep 17 00:00:00 2001 From: Emre Terzioglu <50607143+EmreTech@users.noreply.github.com> Date: Wed, 12 Jul 2023 09:28:36 -0700 Subject: [PATCH 5/5] chore: use ^ syntax --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8f1b95674..bfe58dc14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.8" aiohttp = ">=3.6.0,<4.0.0" -anyio = ">=3.7.0,<4.0.0" +anyio = "^3.7.0" frozendict = "^2.3.0" types-frozendict = "^2.0.6" # Could we extend the version requirement typing-extensions = "^4.1.1" # Same as above
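
For reference, Poetry's caret constraint in the final commit, anyio = "^3.7.0", resolves to the same >=3.7.0,<4.0.0 range that patch 2 wrote out explicitly, so the last change only shortens the spelling without altering which anyio versions are allowed.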