diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 3809d3e..c40519a 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -137,6 +137,8 @@ _SFTPPatList = List[Union[bytes, List[bytes]]] _SFTPStatFunc = Callable[[_SFTPPath], Awaitable['SFTPAttrs']] +_SFTPClientFileOrPath = Union['SFTPClientFile', _SFTPPath] + _SFTPNames = Tuple[Sequence['SFTPName'], bool] _SFTPOSAttrs = Union[os.stat_result, 'SFTPAttrs'] _SFTPOSVFSAttrs = Union[os.statvfs_result, 'SFTPVFSAttrs'] @@ -799,6 +801,22 @@ async def run_task(self, offset: int, size: int) -> Tuple[int, int]: async def run(self) -> None: """Perform parallel file copy""" + if self._srcfs == self._dstfs and \ + isinstance(self._srcfs, SFTPClient): + try: + await self._srcfs.remote_copy(self._srcpath, self._dstpath) + except SFTPOpUnsupported: + pass + else: + self._bytes_copied = self._total_bytes + + if self._progress_handler: + self._progress_handler(self._srcpath, self._dstpath, + self._bytes_copied, + self._total_bytes) + + return + try: self._src = await self._srcfs.open(self._srcpath, 'rb', block_size=0) @@ -808,24 +826,6 @@ async def run(self) -> None: if self._progress_handler and self._total_bytes == 0: self._progress_handler(self._srcpath, self._dstpath, 0, 0) - if self._srcfs == self._dstfs and \ - isinstance(self._srcfs, SFTPClient): - try: - await self._srcfs.remote_copy( - cast(SFTPClientFile, self._src), - cast(SFTPClientFile, self._dst)) - except SFTPOpUnsupported: - pass - else: - self._bytes_copied = self._total_bytes - - if self._progress_handler: - self._progress_handler(self._srcpath, self._dstpath, - self._bytes_copied, - self._total_bytes) - - return - async for _, datalen in self.iter(): if datalen: self._bytes_copied += datalen @@ -4283,9 +4283,9 @@ async def mcopy(self, srcpaths: _SFTPPaths, block_size, max_requests, progress_handler, error_handler) - async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile, - src_offset: int = 0, src_length: int = 0, - dst_offset: int = 0) -> None: + async def remote_copy(self, src: _SFTPClientFileOrPath, + dst: _SFTPClientFileOrPath, src_offset: int = 0, + src_length: int = 0, dst_offset: int = 0) -> None: """Copy data between remote files :param src: @@ -4298,8 +4298,12 @@ async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile, The number of bytes to attempt to copy :param dst_offset: (optional) The offset to begin writing data to - :type src: :class:`SSHClientFile` - :type dst: :class:`SSHClientFile` + :type src: + :class:`SSHClientFile`, :class:`PurePath `, + `str`, or `bytes` + :type dst: + :class:`SSHClientFile`, :class:`PurePath `, + `str`, or `bytes` :type src_offset: `int` :type src_length: `int` :type dst_offset: `int` @@ -4309,6 +4313,12 @@ async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile, """ + if isinstance(src, (bytes, str, PurePath)): + src = await self.open(src, 'rb', block_size=0) + + if isinstance(dst, (bytes, str, PurePath)): + dst = await self.open(dst, 'wb', block_size=0) + await self._handler.copy_data(src.handle, src_offset, src_length, dst.handle, dst_offset)