Skip to content

Commit

Permalink
Add support for str, bytes, or PurePath pathnames in remote_copy
Browse files Browse the repository at this point in the history
This commit adds support for passing in pathnames of type str, bytes, or
PurePath to the new remote_copy() function, in addition to passing in
already-open SFTPClientFile objects.
  • Loading branch information
ronf committed Dec 3, 2024
1 parent 6ecb91e commit 4cf2a39
Showing 1 changed file with 33 additions and 23 deletions.
56 changes: 33 additions & 23 deletions asyncssh/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@
_SFTPPatList = List[Union[bytes, List[bytes]]]
_SFTPStatFunc = Callable[[_SFTPPath], Awaitable['SFTPAttrs']]

_SFTPClientFileOrPath = Union['SFTPClientFile', _SFTPPath]

_SFTPNames = Tuple[Sequence['SFTPName'], bool]
_SFTPOSAttrs = Union[os.stat_result, 'SFTPAttrs']
_SFTPOSVFSAttrs = Union[os.statvfs_result, 'SFTPVFSAttrs']
Expand Down Expand Up @@ -799,6 +801,22 @@ async def run_task(self, offset: int, size: int) -> Tuple[int, int]:
async def run(self) -> None:
"""Perform parallel file copy"""

if self._srcfs == self._dstfs and \
isinstance(self._srcfs, SFTPClient):
try:
await self._srcfs.remote_copy(self._srcpath, self._dstpath)
except SFTPOpUnsupported:
pass
else:
self._bytes_copied = self._total_bytes

if self._progress_handler:
self._progress_handler(self._srcpath, self._dstpath,
self._bytes_copied,
self._total_bytes)

return

try:
self._src = await self._srcfs.open(self._srcpath, 'rb',
block_size=0)
Expand All @@ -808,24 +826,6 @@ async def run(self) -> None:
if self._progress_handler and self._total_bytes == 0:
self._progress_handler(self._srcpath, self._dstpath, 0, 0)

if self._srcfs == self._dstfs and \
isinstance(self._srcfs, SFTPClient):
try:
await self._srcfs.remote_copy(
cast(SFTPClientFile, self._src),
cast(SFTPClientFile, self._dst))
except SFTPOpUnsupported:
pass
else:
self._bytes_copied = self._total_bytes

if self._progress_handler:
self._progress_handler(self._srcpath, self._dstpath,
self._bytes_copied,
self._total_bytes)

return

async for _, datalen in self.iter():
if datalen:
self._bytes_copied += datalen
Expand Down Expand Up @@ -4283,9 +4283,9 @@ async def mcopy(self, srcpaths: _SFTPPaths,
block_size, max_requests, progress_handler,
error_handler)

async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
src_offset: int = 0, src_length: int = 0,
dst_offset: int = 0) -> None:
async def remote_copy(self, src: _SFTPClientFileOrPath,
dst: _SFTPClientFileOrPath, src_offset: int = 0,
src_length: int = 0, dst_offset: int = 0) -> None:
"""Copy data between remote files
:param src:
Expand All @@ -4298,8 +4298,12 @@ async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
The number of bytes to attempt to copy
:param dst_offset: (optional)
The offset to begin writing data to
:type src: :class:`SSHClientFile`
:type dst: :class:`SSHClientFile`
:type src:
:class:`SSHClientFile`, :class:`PurePath <pathlib.PurePath>`,
`str`, or `bytes`
:type dst:
:class:`SSHClientFile`, :class:`PurePath <pathlib.PurePath>`,
`str`, or `bytes`
:type src_offset: `int`
:type src_length: `int`
:type dst_offset: `int`
Expand All @@ -4309,6 +4313,12 @@ async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
"""

if isinstance(src, (bytes, str, PurePath)):
src = await self.open(src, 'rb', block_size=0)

if isinstance(dst, (bytes, str, PurePath)):
dst = await self.open(dst, 'wb', block_size=0)

await self._handler.copy_data(src.handle, src_offset, src_length,
dst.handle, dst_offset)

Expand Down

0 comments on commit 4cf2a39

Please sign in to comment.