From e8169bf19c74e5dcaa184123764db403ad131a36 Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Thu, 5 Dec 2024 07:33:13 -0800 Subject: [PATCH] Adjust remote_only argument to only apply to copying plain files This commit adjusts the behavior of the "remote_only" argument in SFTPClient copy() and mcopy() to only raise an error when attempting to copy a plain file on a system which doesn't support remote copy. Other operations during the copy such as creating directories and symlinks and setting attributes will always be allowed. When an error handler is set, it will be called once for each plain file copy which is attempted when remote_only is set and the server doesn't support remote copy. The exceptions will include the source and destination pathname, to allow other forms of copy to be performed. --- asyncssh/sftp.py | 25 +++++++++++++------------ tests/test_sftp.py | 11 ++++++++--- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 4b8f0ca..209bee3 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -3759,7 +3759,8 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, preserve: bool, recurse: bool, follow_symlinks: bool, block_size: int, max_requests: int, progress_handler: SFTPProgressHandler, - error_handler: SFTPErrorHandler) -> None: + error_handler: SFTPErrorHandler, + remote_only: bool) -> None: """Copy a file, directory, or symbolic link""" try: @@ -3795,7 +3796,8 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, await self._copy(srcfs, dstfs, srcfile, dstfile, srcname.attrs, preserve, recurse, follow_symlinks, block_size, max_requests, - progress_handler, error_handler) + progress_handler, error_handler, + remote_only) self.logger.info(' Finished copy of directory %s to %s', srcpath, dstpath) @@ -3810,6 +3812,9 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, else: self.logger.info(' Copying file %s to %s', srcpath, dstpath) + if remote_only and not self.supports_remote_copy: + raise SFTPOpUnsupported('Remote copy not supported') + await _SFTPFileCopier(block_size, max_requests, 0, srcattrs.size or 0, srcfs, dstfs, srcpath, dstpath, progress_handler).run() @@ -3846,7 +3851,8 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, recurse: bool, follow_symlinks: bool, block_size: int, max_requests: int, progress_handler: SFTPProgressHandler, - error_handler: SFTPErrorHandler) -> None: + error_handler: SFTPErrorHandler, + remote_only: bool = False) -> None: """Begin a new file upload, download, or copy""" if block_size <= 0: @@ -3903,7 +3909,8 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, await self._copy(srcfs, dstfs, srcfile, dstfile, srcname.attrs, preserve, recurse, follow_symlinks, block_size, - max_requests, progress_handler, error_handler) + max_requests, progress_handler, error_handler, + remote_only) async def get(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, @@ -4222,13 +4229,10 @@ async def copy(self, srcpaths: _SFTPPaths, """ - if remote_only and not self.supports_remote_copy: - raise SFTPOpUnsupported('Remote copy not supported') - await self._begin_copy(self, self, srcpaths, dstpath, 'remote copy', False, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, - error_handler) + error_handler, remote_only) async def mget(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, @@ -4295,13 +4299,10 @@ async def mcopy(self, srcpaths: _SFTPPaths, """ - if remote_only and not self.supports_remote_copy: - raise SFTPOpUnsupported('Remote copy not supported') - await self._begin_copy(self, self, srcpaths, dstpath, 'remote mcopy', True, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, - error_handler) + error_handler, remote_only) async def remote_copy(self, src: _SFTPClientFileOrPath, dst: _SFTPClientFileOrPath, src_offset: int = 0, diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 59b377d..a853508 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -777,9 +777,14 @@ async def _test_copy_remote_only(self, sftp): for method in ('copy', 'mcopy'): with self.subTest(method=method): - with self.assertRaises(SFTPOpUnsupported): - await getattr(sftp, method)('src', 'dst', - remote_only=True) + try: + self._create_file('src') + + with self.assertRaises(SFTPOpUnsupported): + await getattr(sftp, method)('src', 'dst', + remote_only=True) + finally: + remove('src') with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): # pylint: disable=no-value-for-parameter