diff --git a/aioasuswrt/connection.py b/aioasuswrt/connection.py index 1fac42d..7c087ca 100644 --- a/aioasuswrt/connection.py +++ b/aioasuswrt/connection.py @@ -124,33 +124,41 @@ def __init__( super().__init__(host, port or 22, username, password) self._ssh_key = ssh_key self._client = None + self._lock = asyncio.Lock() async def _async_call_command(self, command: str) -> List[str]: """Run commands through an SSH connection. Connect to the SSH server if not currently connected, otherwise use the existing connection. """ - try: - if not self.is_connected: - await self._async_connect() - if not self._client: - raise _CommandException - - result = await asyncio.wait_for( - self._client.run(f"{_PATH_EXPORT_COMMAND} && {command}"), - 9, - ) - except asyncssh.misc.ChannelOpenError as ex: - self._disconnect() - _LOGGER.warning("Not connected to host") - raise _CommandException from ex - except TimeoutError as ex: - self._disconnect() - _LOGGER.error("Host timeout.") - raise _CommandException from ex - - return result.stdout.split("\n") - + if not self.is_connected: + if not retry: + await self.async_connect() + return await self.async_run_command(command, retry=True) + else: + _LOGGER.error("Cant connect to host, giving up!") + return [] + else: + try: + async with self._lock: + result = await asyncio.wait_for( + self._client.run("%s && %s" % (_PATH_EXPORT_COMMAND, command)), + 9, + ) + except asyncssh.misc.ChannelOpenError: + if not retry: + await self.async_connect() + return await self.async_run_command(command, retry=True) + else: + _LOGGER.error("Cant connect to host, giving up!") + return [] + except TimeoutError: + self._client = None + _LOGGER.error("Host timeout.") + return [] + else: + return result.stdout.split("\n") + @property def is_connected(self) -> bool: """Do we have a connection.""" @@ -166,8 +174,15 @@ async def _async_connect(self): "known_hosts": None, 'server_host_key_algs': ['ssh-rsa'], } - - self._client = await asyncssh.connect(self._host, **kwargs) + async with self._lock: + if self.is_connected: + _LOGGER.debug("reconnecting; old connection had local port %d", self._client._local_port) + self._client.close() + self._client = None + else: + _LOGGER.debug("reconnecting; no old connection existed") + self._client = await asyncssh.connect(self._host, **kwargs) + _LOGGER.debug("reconnected; new connection has local port %d", self._client._local_port) def _disconnect(self): self._client = None diff --git a/setup.py b/setup.py index 5b23047..11c1a8c 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ from setuptools import find_packages, setup __author__ = "Magnus Knutas" -VERSION = "1.4.0" +VERSION = "1.4.1" with open("README.md", "r") as fh: long_description = fh.read()