Skip to content

Commit

Permalink
Merge pull request #91 from danielskowronski/fix-conn-leak
Browse files Browse the repository at this point in the history
Fix connection leak and add Coroutine-Safety
  • Loading branch information
kennedyshead authored Dec 13, 2024
2 parents 27477d8 + 932e2e2 commit 33e6ce6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
61 changes: 38 additions & 23 deletions aioasuswrt/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,33 +124,41 @@ def __init__(
super().__init__(host, port or 22, username, password)
self._ssh_key = ssh_key
self._client = None
self._lock = asyncio.Lock()

async def _async_call_command(self, command: str) -> List[str]:
"""Run commands through an SSH connection.
Connect to the SSH server if not currently connected, otherwise
use the existing connection.
"""
try:
if not self.is_connected:
await self._async_connect()
if not self._client:
raise _CommandException

result = await asyncio.wait_for(
self._client.run(f"{_PATH_EXPORT_COMMAND} && {command}"),
9,
)
except asyncssh.misc.ChannelOpenError as ex:
self._disconnect()
_LOGGER.warning("Not connected to host")
raise _CommandException from ex
except TimeoutError as ex:
self._disconnect()
_LOGGER.error("Host timeout.")
raise _CommandException from ex

return result.stdout.split("\n")

if not self.is_connected:
if not retry:
await self.async_connect()
return await self.async_run_command(command, retry=True)
else:
_LOGGER.error("Cant connect to host, giving up!")
return []
else:
try:
async with self._lock:
result = await asyncio.wait_for(
self._client.run("%s && %s" % (_PATH_EXPORT_COMMAND, command)),
9,
)
except asyncssh.misc.ChannelOpenError:
if not retry:
await self.async_connect()
return await self.async_run_command(command, retry=True)
else:
_LOGGER.error("Cant connect to host, giving up!")
return []
except TimeoutError:
self._client = None
_LOGGER.error("Host timeout.")
return []
else:
return result.stdout.split("\n")

@property
def is_connected(self) -> bool:
"""Do we have a connection."""
Expand All @@ -166,8 +174,15 @@ async def _async_connect(self):
"known_hosts": None,
'server_host_key_algs': ['ssh-rsa'],
}

self._client = await asyncssh.connect(self._host, **kwargs)
async with self._lock:
if self.is_connected:
_LOGGER.debug("reconnecting; old connection had local port %d", self._client._local_port)
self._client.close()
self._client = None
else:
_LOGGER.debug("reconnecting; no old connection existed")
self._client = await asyncssh.connect(self._host, **kwargs)
_LOGGER.debug("reconnected; new connection has local port %d", self._client._local_port)

def _disconnect(self):
self._client = None
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from setuptools import find_packages, setup

__author__ = "Magnus Knutas"
VERSION = "1.4.0"
VERSION = "1.4.1"

with open("README.md", "r") as fh:
long_description = fh.read()
Expand Down

0 comments on commit 33e6ce6

Please sign in to comment.