From fadfb2f66298c35d6274d3ca61fd9c2f8c7c2ef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Skowro=C5=84ski?= Date: Thu, 2 Nov 2023 17:51:02 +0100 Subject: [PATCH 1/3] solve connection leaks and add Coroutine-Safety #90 --- aioasuswrt/connection.py | 13 ++++++++++--- setup.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/aioasuswrt/connection.py b/aioasuswrt/connection.py index 280af59..673bc06 100644 --- a/aioasuswrt/connection.py +++ b/aioasuswrt/connection.py @@ -24,6 +24,7 @@ def __init__(self, host, port, username, password, ssh_key): self._password = password self._ssh_key = ssh_key self._client = None + self._lock = asyncio.Lock() async def async_run_command(self, command, retry=False): """Run commands through an SSH connection. @@ -65,7 +66,6 @@ def is_connected(self): async def async_connect(self): """Fetches the client or creates a new one.""" - kwargs = { "username": self._username if self._username else None, "client_keys": [self._ssh_key] if self._ssh_key else None, @@ -74,8 +74,15 @@ async def async_connect(self): "known_hosts": None, 'server_host_key_algs': ['ssh-rsa'], } - - self._client = await asyncssh.connect(self._host, **kwargs) + async with self._lock: + if self.is_connected: + _LOGGER.debug("reconnecting; old connection had local port %d", self._client._local_port) + self._client.close() + self._client = None + else: + _LOGGER.debug("reconnecting; no old connection existed") + self._client = await asyncssh.connect(self._host, **kwargs) + _LOGGER.debug("reconnected; new connection has local port %d", self._client._local_port) class TelnetConnection: diff --git a/setup.py b/setup.py index 5b23047..11c1a8c 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ from setuptools import find_packages, setup __author__ = "Magnus Knutas" -VERSION = "1.4.0" +VERSION = "1.4.1" with open("README.md", "r") as fh: long_description = fh.read() From 8ea026eb456d2e82b5b78b8af947a6668c64bef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Skowro=C5=84ski?= Date: Thu, 2 Nov 2023 18:12:24 +0100 Subject: [PATCH 2/3] reflow async_run_command and make it Coroutine-Safe - make logic more readable - safeguards race-condition between connect and run_command --- aioasuswrt/connection.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/aioasuswrt/connection.py b/aioasuswrt/connection.py index 673bc06..acad288 100644 --- a/aioasuswrt/connection.py +++ b/aioasuswrt/connection.py @@ -31,11 +31,15 @@ async def async_run_command(self, command, retry=False): Connect to the SSH server if not currently connected, otherwise use the existing connection. """ - if self._client is None and not retry: - await self.async_connect() - return await self.async_run_command(command, retry=True) - else: - if self._client is not None: + async with self._lock: + if self._client is None: + if not retry: + await self.async_connect() + return await self.async_run_command(command, retry=True) + else: + _LOGGER.error("Cant connect to host, giving up!") + return [] + else: try: result = await asyncio.wait_for( self._client.run("%s && %s" % (_PATH_EXPORT_COMMAND, command)), @@ -52,12 +56,9 @@ async def async_run_command(self, command, retry=False): self._client = None _LOGGER.error("Host timeout.") return [] - - return result.stdout.split("\n") - - else: - _LOGGER.error("Cant connect to host, giving up!") - return [] + else: + return result.stdout.split("\n") + @property def is_connected(self): From f10aaeddf50baae05e03c49da8546bbcf9b6ce60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Skowro=C5=84ski?= Date: Thu, 2 Nov 2023 18:40:24 +0100 Subject: [PATCH 3/3] revert brave locking of recursive corourine and guard only run --- aioasuswrt/connection.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/aioasuswrt/connection.py b/aioasuswrt/connection.py index acad288..0b4d33a 100644 --- a/aioasuswrt/connection.py +++ b/aioasuswrt/connection.py @@ -31,33 +31,33 @@ async def async_run_command(self, command, retry=False): Connect to the SSH server if not currently connected, otherwise use the existing connection. """ - async with self._lock: - if self._client is None: + if not self.is_connected: + if not retry: + await self.async_connect() + return await self.async_run_command(command, retry=True) + else: + _LOGGER.error("Cant connect to host, giving up!") + return [] + else: + try: + async with self._lock: + result = await asyncio.wait_for( + self._client.run("%s && %s" % (_PATH_EXPORT_COMMAND, command)), + 9, + ) + except asyncssh.misc.ChannelOpenError: if not retry: await self.async_connect() return await self.async_run_command(command, retry=True) else: _LOGGER.error("Cant connect to host, giving up!") return [] + except TimeoutError: + self._client = None + _LOGGER.error("Host timeout.") + return [] else: - try: - result = await asyncio.wait_for( - self._client.run("%s && %s" % (_PATH_EXPORT_COMMAND, command)), - 9, - ) - except asyncssh.misc.ChannelOpenError: - if not retry: - await self.async_connect() - return await self.async_run_command(command, retry=True) - else: - _LOGGER.error("Cant connect to host, giving up!") - return [] - except TimeoutError: - self._client = None - _LOGGER.error("Host timeout.") - return [] - else: - return result.stdout.split("\n") + return result.stdout.split("\n") @property