Skip to content

Commit

Permalink
Improve handling of tast cancellation in SCP
Browse files Browse the repository at this point in the history
This commit improves on the handling of CancelledError in SCP,
avoiding an issue where AsyncSSH could get stuck waiting for the
channel to close in some cases when a KeyboardInterrupt or timeout
occurred. Thanks go to Max Orlov for reporting the problem and
providing code to reproduce it.
  • Loading branch information
ronf committed Feb 9, 2024
1 parent 3807b64 commit 676534b
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions asyncssh/scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,12 +388,14 @@ def handle_error(self, exc: Exception) -> None:
elif self._error_handler:
self._error_handler(exc)

async def close(self) -> None:
async def close(self, cancelled: bool = False) -> None:
"""Close an SCP session"""

self.logger.info('Stopping remote SCP')

if self._server:
if cancelled:
self._writer.channel.abort()
elif self._server:
cast('SSHServerChannel', self._writer.channel).exit(0)
else:
self._writer.close()
Expand Down Expand Up @@ -535,6 +537,8 @@ async def _send_files(self, srcpath: bytes, dstpath: bytes,
async def run(self, srcpath: _SCPPath) -> None:
"""Start SCP transfer"""

cancelled = False

try:
if isinstance(srcpath, PurePath):
srcpath = str(srcpath)
Expand All @@ -550,10 +554,12 @@ async def run(self, srcpath: _SCPPath) -> None:
for name in await SFTPGlob(self._fs).match(srcpath):
await self._send_files(cast(bytes, name.filename),
b'', name.attrs)
except asyncio.CancelledError:
cancelled = True
except (OSError, SFTPError) as exc:
self.handle_error(exc)
finally:
await self.close()
await self.close(cancelled)


class _SCPSink(_SCPHandler):
Expand Down Expand Up @@ -699,6 +705,8 @@ async def _recv_files(self, srcpath: bytes, dstpath: bytes) -> None:
async def run(self, dstpath: _SCPPath) -> None:
"""Start SCP file receive"""

cancelled = False

try:
if isinstance(dstpath, PurePath):
dstpath = str(dstpath)
Expand All @@ -711,10 +719,12 @@ async def run(self, dstpath: _SCPPath) -> None:
dstpath))
else:
await self._recv_files(b'', dstpath)
except asyncio.CancelledError as exc:
cancelled = True
except (OSError, SFTPError, ValueError) as exc:
self.handle_error(exc)
finally:
await self.close()
await self.close(cancelled)


class _SCPCopier:
Expand Down Expand Up @@ -870,13 +880,17 @@ async def _copy_files(self) -> None:
async def run(self) -> None:
"""Start SCP remote-to-remote transfer"""

cancelled = False

try:
await self._copy_files()
except asyncio.CancelledError:
cancelled = True
except (OSError, SFTPError) as exc:
self._handle_error(exc)
finally:
await self._source.close()
await self._sink.close()
await self._source.close(cancelled)
await self._sink.close(cancelled)


async def scp(srcpaths: Union[_SCPConnPath, Sequence[_SCPConnPath]],
Expand Down

0 comments on commit 676534b

Please sign in to comment.