Skip to content

Commit

Permalink
Improve support for Windows pathnames when using ProxyCommand
Browse files Browse the repository at this point in the history
This commit changes the parsing of a proxy_command (or ProxyCommand in
SSH config files) to better cope with backslashes that might appear in
pathnames on Windows. Thanks go to GitHub user chipolux for reporting
the issue and investigating the existing OpenSSH parsing behavior.
  • Loading branch information
ronf committed Aug 10, 2024
1 parent a50f9b3 commit 4f3de9e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
4 changes: 2 additions & 2 deletions asyncssh/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class SSHClientConfig(SSHConfig):
"""Settings from an OpenSSH client config file"""

_conditionals = {'host', 'match'}
_no_split = {'remotecommand'}
_no_split = {'proxycommand', 'remotecommand'}
_percent_expand = {'CertificateFile', 'IdentityAgent',
'IdentityFile', 'ProxyCommand', 'RemoteCommand'}

Expand Down Expand Up @@ -559,7 +559,7 @@ def _set_tokens(self) -> None:
('PKCS11Provider', SSHConfig._set_string),
('PreferredAuthentications', SSHConfig._set_string),
('Port', SSHConfig._set_int),
('ProxyCommand', SSHConfig._set_string_list),
('ProxyCommand', SSHConfig._set_string),
('ProxyJump', SSHConfig._set_string),
('PubkeyAuthentication', SSHConfig._set_bool),
('RekeyLimit', SSHConfig._set_rekey_limits),
Expand Down
16 changes: 9 additions & 7 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
from .misc import TermModesArg, TermSizeArg
from .misc import async_context_manager, construct_disc_error
from .misc import get_symbol_names, ip_address, map_handler_name
from .misc import parse_byte_count, parse_time_interval
from .misc import parse_byte_count, parse_time_interval, split_args

from .packet import Boolean, Byte, NameList, String, UInt32, PacketDecodeError
from .packet import SSHPacket, SSHPacketHandler, SSHPacketLogger
Expand Down Expand Up @@ -231,7 +231,7 @@ async def create_server(self, session_factory: TCPListenerFactory,
_GlobalRequestResult = Tuple[int, SSHPacket]
_KeyOrCertOptions = Mapping[str, object]
_ListenerArg = Union[bool, SSHListener]
_ProxyCommand = Optional[Sequence[str]]
_ProxyCommand = Optional[Union[str, Sequence[str]]]
_RequestPTY = Union[bool, str]

_TCPServerHandlerFactory = Callable[[str, int], SSHSocketSessionFactory]
Expand Down Expand Up @@ -7144,11 +7144,13 @@ def prepare(self, config: SSHConfig, # type: ignore
self.tunnel = tunnel if tunnel != () else config.get('ProxyJump')
self.passphrase = passphrase

if proxy_command == ():
proxy_command = cast(Optional[str], config.get('ProxyCommand'))

if isinstance(proxy_command, str):
proxy_command = shlex.split(proxy_command)
proxy_command = split_args(proxy_command)

self.proxy_command = proxy_command if proxy_command != () else \
cast(Sequence[str], config.get('ProxyCommand'))
self.proxy_command = proxy_command

self.family = cast(int, family if family != () else
config.get('AddressFamily', socket.AF_UNSPEC))
Expand Down Expand Up @@ -9224,7 +9226,7 @@ async def create_server(server_factory: _ServerFactory,
async def get_server_host_key(
host = '', port: DefTuple[int] = (), *,
tunnel: DefTuple[_TunnelConnector] = (),
proxy_command: DefTuple[str] = (), family: DefTuple[int] = (),
proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (),
flags: int = 0, local_addr: DefTuple[HostPort] = (),
sock: Optional[socket.socket] = None,
client_version: DefTuple[BytesOrStr] = (),
Expand Down Expand Up @@ -9368,7 +9370,7 @@ def conn_factory() -> SSHClientConnection:
async def get_server_auth_methods(
host = '', port: DefTuple[int] = (), username: DefTuple[str] = (), *,
tunnel: DefTuple[_TunnelConnector] = (),
proxy_command: DefTuple[str] = (), family: DefTuple[int] = (),
proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (),
flags: int = 0, local_addr: DefTuple[HostPort] = (),
sock: Optional[socket.socket] = None,
client_version: DefTuple[BytesOrStr] = (),
Expand Down
13 changes: 13 additions & 0 deletions asyncssh/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import functools
import ipaddress
import re
import shlex
import socket
import sys

Expand Down Expand Up @@ -269,6 +270,18 @@ def parse_time_interval(value: str) -> float:
return _parse_units(value, _time_units, 'time interval')


def split_args(command: str) -> Sequence[str]:
"""Split a command string into a list of arguments"""

lex = shlex.shlex(command, posix=True)
lex.whitespace_split = True

if sys.platform == 'win32': # pragma: no cover
lex.escape = []

return list(lex)


_ACM = TypeVar('_ACM', bound=AsyncContextManager, covariant=True)

class _ACMWrapper(Generic[_ACM]):
Expand Down

0 comments on commit 4f3de9e

Please sign in to comment.