Skip to content

Commit

Permalink
Fix logging and typing issues in SFTP high-level copy functions
Browse files Browse the repository at this point in the history
This commit fixes logging and typing issues with SFTP get, put, copy,
mget, mput, and mcopy functions. AsyncSSH should now properly handle
sequences which mix bytes, str, and PurePath entries and also fixes type
annotations for these functions to indicate that they accept either a
single path or a list of paths. Thanks go to GitHub user eyalgolan1337
for reporting these issues!
  • Loading branch information
ronf committed Jul 15, 2024
1 parent 0eac029 commit b21e758
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 24 deletions.
34 changes: 22 additions & 12 deletions asyncssh/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,39 @@ def get_child(self, child: str = '', context: str = '') -> 'SSHLogger':
def log(self, level: int, msg: object, *args, **kwargs) -> None:
"""Log a message to the underlying logger"""

def _text(arg: _LogArg) -> str:
def _item_text(item: _LogArg) -> str:
"""Convert a list item to text"""

if isinstance(item, bytes):
result = item.decode('utf-8', errors='replace')

if not result.isprintable():
result = repr(result)[1:-1]
elif not isinstance(item, str):
result = str(item)
else:
result = item

return result

def _text(arg: _LogArg) -> _LogArg:
"""Convert a log argument to text"""

result: _LogArg

if isinstance(arg, list):
if arg and isinstance(arg[0], bytes):
result = b','.join(arg).decode('utf-8', errors='replace')
else:
result = ','.join(arg)
result = ','.join(_item_text(item) for item in arg)
elif isinstance(arg, tuple):
host, port = arg

if host:
result = '%s, port %d' % (host, port) if port else host
else:
result = 'port %d' % port if port else 'dynamic port'
elif isinstance(arg, bytes):
result = _item_text(arg)
else:
result = cast(str, arg)

if isinstance(result, bytes):
result = result.decode('ascii', errors='backslashreplace')

if not result.isprintable():
result = repr(result)[1:-1]
result = arg

return result

Expand Down
22 changes: 10 additions & 12 deletions asyncssh/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3679,24 +3679,22 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol,
raise

async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol,
srcpaths: Sequence[_SFTPPath],
dstpath: Optional[_SFTPPath],
srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath],
copy_type: str, expand_glob: bool, preserve: bool,
recurse: bool, follow_symlinks: bool,
block_size: int, max_requests: int,
progress_handler: SFTPProgressHandler,
error_handler: SFTPErrorHandler) -> None:
"""Begin a new file upload, download, or copy"""

if isinstance(srcpaths, tuple):
if isinstance(srcpaths, (bytes, str, PurePath)):
srcpaths = [srcpaths]
elif not isinstance(srcpaths, list):
srcpaths = list(srcpaths)

self.logger.info('Starting SFTP %s of %s to %s',
copy_type, srcpaths, dstpath)

if isinstance(srcpaths, (bytes, str, PurePath)):
srcpaths = [srcpaths]

srcnames: List[SFTPName] = []

if expand_glob:
Expand Down Expand Up @@ -3741,7 +3739,7 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol,
preserve, recurse, follow_symlinks, block_size,
max_requests, progress_handler, error_handler)

async def get(self, remotepaths: Sequence[_SFTPPath],
async def get(self, remotepaths: _SFTPPaths,
localpath: Optional[_SFTPPath] = None, *,
preserve: bool = False, recurse: bool = False,
follow_symlinks: bool = False,
Expand Down Expand Up @@ -3846,7 +3844,7 @@ async def get(self, remotepaths: Sequence[_SFTPPath],
block_size, max_requests, progress_handler,
error_handler)

async def put(self, localpaths: Sequence[_SFTPPath],
async def put(self, localpaths: _SFTPPaths,
remotepath: Optional[_SFTPPath] = None, *,
preserve: bool = False, recurse: bool = False,
follow_symlinks: bool = False,
Expand Down Expand Up @@ -3951,7 +3949,7 @@ async def put(self, localpaths: Sequence[_SFTPPath],
block_size, max_requests, progress_handler,
error_handler)

async def copy(self, srcpaths: Sequence[_SFTPPath],
async def copy(self, srcpaths: _SFTPPaths,
dstpath: Optional[_SFTPPath] = None, *,
preserve: bool = False, recurse: bool = False,
follow_symlinks: bool = False,
Expand Down Expand Up @@ -4056,7 +4054,7 @@ async def copy(self, srcpaths: Sequence[_SFTPPath],
block_size, max_requests, progress_handler,
error_handler)

async def mget(self, remotepaths: Sequence[_SFTPPath],
async def mget(self, remotepaths: _SFTPPaths,
localpath: Optional[_SFTPPath] = None, *,
preserve: bool = False, recurse: bool = False,
follow_symlinks: bool = False,
Expand All @@ -4080,7 +4078,7 @@ async def mget(self, remotepaths: Sequence[_SFTPPath],
block_size, max_requests, progress_handler,
error_handler)

async def mput(self, localpaths: Sequence[_SFTPPath],
async def mput(self, localpaths: _SFTPPaths,
remotepath: Optional[_SFTPPath] = None, *,
preserve: bool = False, recurse: bool = False,
follow_symlinks: bool = False,
Expand All @@ -4104,7 +4102,7 @@ async def mput(self, localpaths: Sequence[_SFTPPath],
block_size, max_requests, progress_handler,
error_handler)

async def mcopy(self, srcpaths: Sequence[_SFTPPath],
async def mcopy(self, srcpaths: _SFTPPaths,
dstpath: Optional[_SFTPPath] = None, *,
preserve: bool = False, recurse: bool = False,
follow_symlinks: bool = False,
Expand Down

0 comments on commit b21e758

Please sign in to comment.