Skip to content

Commit

Permalink
Add support for getting/setting environment variables using byte strings
Browse files Browse the repository at this point in the history
This commit allows applications to pass in byte strings in addition to
Unicode strings when setting environment variables. As before, Unicode
strings will be UTF-8 encoded, but byte strings can support passing
binary data, or text encoded with other encodings.

On the server side, a new get_environment_bytes() method can be used to
get the environment variables as a dictionary with byte srtings as the
keys and values, instead of the Unicode strings returned by
get_environment(), allowing access to the raw binary data. The existing
get_environment() remains available, but it will only provide access
to the environment variables which have valid UTF-8 keys and values.

It is up to the applicaiton code to know what encoding (if any) should
be used when getting or setting environment variables.
  • Loading branch information
ronf committed Sep 15, 2024
1 parent f9138f8 commit d6a65a1
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 78 deletions.
84 changes: 57 additions & 27 deletions asyncssh/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@

from .logging import SSHLogger

from .misc import ChannelOpenError, MaybeAwait, ProtocolError
from .misc import ChannelOpenError, EnvIter, MaybeAwait, ProtocolError
from .misc import TermModes, TermSize, TermSizeArg
from .misc import get_symbol_names, map_handler_name
from .misc import decode_env, encode_env, get_symbol_names, map_handler_name

from .packet import Boolean, Byte, String, UInt32, SSHPacket, SSHPacketHandler

Expand Down Expand Up @@ -115,7 +115,9 @@ def __init__(self, conn: 'SSHConnection',
self._send_high_water: int
self._send_low_water: int

self._env: Dict[str, str] = {}
self._env: Dict[bytes, bytes] = {}
self._str_env: Optional[Dict[str, str]] = None

self._command: Optional[str] = None
self._subsystem: Optional[str] = None

Expand Down Expand Up @@ -1030,12 +1032,41 @@ def get_environment(self) -> Mapping[str, str]:
"""Return the environment for this session
This method returns the environment set by the client when
the session was opened. On the server, calls to this method
should only be made after :meth:`session_started
<SSHServerSession.session_started>` has been called on the
:class:`SSHServerSession`. When using the stream-based API,
calls to this can be made at any time after the handler
function has started up.
the session was opened. Keys and values are of type `str`
and this object only provides access to keys and values sent
as valid UTF-8 strings. Use :meth:`get_environment_bytes`
if you need to access environment variables with keys or
values containing binary data or non-UTF-8 encodings.
On the server, calls to this method should only be made after
:meth:`session_started <SSHServerSession.session_started>` has
been called on the :class:`SSHServerSession`. When using the
stream-based API, calls to this can be made at any time after
the handler function has started up.
:returns: A dictionary containing the environment variables
set by the client
"""

if self._str_env is None:
self._str_env = dict(decode_env(self._env))

return MappingProxyType(self._str_env)

def get_environment_bytes(self) -> Mapping[bytes, bytes]:
"""Return the environment for this session
This method returns the environment set by the client when
the session was opened. Keys and values are of type `bytes`
and can include arbitrary binary data, with the exception
of NUL (\0) bytes.
On the server, calls to this method should only be made after
:meth:`session_started <SSHServerSession.session_started>` has
been called on the :class:`SSHServerSession`. When using the
stream-based API, calls to this can be made at any time after
the handler function has started up.
:returns: A dictionary containing the environment variables
set by the client
Expand Down Expand Up @@ -1097,7 +1128,7 @@ def __init__(self, conn: 'SSHClientConnection',

async def create(self, session_factory: SSHClientSessionFactory[AnyStr],
command: Optional[str], subsystem: Optional[str],
env: Dict[str, str], request_pty: bool,
env: Dict[bytes, bytes], request_pty: bool,
term_type: Optional[str], term_size: TermSizeArg,
term_modes: TermModes, x11_forwarding: Union[bool, str],
x11_display: Optional[str], x11_auth_path: Optional[str],
Expand All @@ -1119,10 +1150,16 @@ async def create(self, session_factory: SSHClientSessionFactory[AnyStr],
self._command = command
self._subsystem = subsystem

for name, env_value in env.items():
self.logger.debug1(' Env: %s=%s', name, env_value)
self._send_request(b'env', String(str(name)),
String(str(env_value)))
for key, value in env.items():
self.logger.debug1(' Env: %s=%s', key, value)

if not isinstance(key, (bytes, str)):
key = str(key)

if not isinstance(value, (bytes, str)):
value = str(value)

self._send_request(b'env', String(key), String(value))

if request_pty:
self.logger.debug1(' Terminal type: %s', term_type or 'None')
Expand Down Expand Up @@ -1460,8 +1497,8 @@ def __init__(self, conn: 'SSHServerConnection',

super().__init__(conn, loop, encoding, errors, window, max_pktsize)

self._env = cast(Dict[str, str],
conn.get_key_option('environment', {}))
env_option = cast(EnvIter, conn.get_key_option('environment', {}))
self._env = dict(encode_env(env_option))

self._allow_pty = allow_pty
self._line_editor = line_editor
Expand Down Expand Up @@ -1616,19 +1653,12 @@ async def _finish_agent_req_request(self) -> None:
def _process_env_request(self, packet: SSHPacket) -> bool:
"""Process a request to set an environment variable"""

name_bytes = packet.get_string()
value_bytes = packet.get_string()
key = packet.get_string()
value = packet.get_string()
packet.check_end()

try:
name = name_bytes.decode('utf-8')
value = value_bytes.decode('utf-8')
except UnicodeDecodeError:
self.logger.debug1('Invalid environment data')
return False

self.logger.debug1(' Env: %s=%s', name, value)
self._env[name] = value
self.logger.debug1(' Env: %s=%s', key, value)
self._env[key] = value
return True

def _start_session(self, command: Optional[str] = None,
Expand Down
50 changes: 19 additions & 31 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@

from .mac import get_mac_algs, get_default_mac_algs

from .misc import BytesOrStr, BytesOrStrDict, DefTuple, FilePath, HostPort
from .misc import IPNetwork, MaybeAwait, OptExcInfo, Options, SockAddr
from .misc import BytesOrStr, BytesOrStrDict, DefTuple, Env, EnvList, FilePath
from .misc import HostPort, IPNetwork, MaybeAwait, OptExcInfo, Options, SockAddr
from .misc import ChannelListenError, ChannelOpenError, CompressionError
from .misc import DisconnectError, ConnectionLost, HostKeyNotVerifiable
from .misc import KeyExchangeFailed, IllegalUserName, MACError
from .misc import PasswordChangeRequired, PermissionDenied, ProtocolError
from .misc import ProtocolNotSupported, ServiceNotAvailable
from .misc import TermModesArg, TermSizeArg
from .misc import async_context_manager, construct_disc_error
from .misc import get_symbol_names, ip_address, map_handler_name
from .misc import async_context_manager, construct_disc_error, encode_env
from .misc import get_symbol_names, ip_address, lookup_env, map_handler_name
from .misc import parse_byte_count, parse_time_interval, split_args

from .packet import Boolean, Byte, NameList, String, UInt32, PacketDecodeError
Expand Down Expand Up @@ -224,9 +224,6 @@ async def create_server(self, session_factory: TCPListenerFactory,
_ClientHostKey = Union[SSHKeyPair, SSHKeySignKeyPair]
_ClientKeysArg = Union[KeyListArg, KeyPairListArg]

_Env = Optional[Union[Mapping[str, str], Sequence[str]]]
_SendEnv = Optional[Sequence[str]]

_GlobalRequest = Tuple[Optional[_PacketHandler], SSHPacket, bool]
_GlobalRequestResult = Tuple[int, SSHPacket]
_KeyOrCertOptions = Mapping[str, object]
Expand Down Expand Up @@ -4073,8 +4070,8 @@ def detach_x11_listener(self, chan: SSHChannel[AnyStr]) -> None:
async def create_session(self, session_factory: SSHClientSessionFactory,
command: DefTuple[Optional[str]] = (), *,
subsystem: DefTuple[Optional[str]]= (),
env: DefTuple[_Env] = (),
send_env: DefTuple[_SendEnv] = (),
env: DefTuple[Env] = (),
send_env: DefTuple[Optional[EnvList]] = (),
request_pty: DefTuple[Union[bool, str]] = (),
term_type: DefTuple[Optional[str]] = (),
term_size: DefTuple[TermSizeArg] = (),
Expand Down Expand Up @@ -4179,8 +4176,8 @@ async def create_session(self, session_factory: SSHClientSessionFactory,
:type session_factory: `callable`
:type command: `str`
:type subsystem: `str`
:type env: `dict` with `str` keys and values
:type send_env: `list` of `str`
:type env: `dict` with `bytes` or `str` keys and values
:type send_env: `list` of `bytes` or `str`
:type request_pty: `bool`, `'force'`, or `'auto'`
:type term_type: `str`
:type term_size: `tuple` of 2 or 4 `int` values
Expand Down Expand Up @@ -4248,22 +4245,13 @@ async def create_session(self, session_factory: SSHClientSessionFactory,
if max_pktsize == ():
max_pktsize = self._options.max_pktsize

new_env: Dict[str, str] = {}
new_env: Dict[bytes, bytes] = {}

if send_env:
for key in send_env:
pattern = WildcardPattern(key)
new_env.update((key, value) for key, value in os.environ.items()
if pattern.matches(key))
new_env.update(lookup_env(send_env))

if env:
try:
if isinstance(env, list):
new_env.update((item.split('=', 1) for item in env))
else:
new_env.update(cast(Mapping[str, str], env))
except ValueError:
raise ValueError('Invalid environment value') from None
new_env.update(encode_env(env))

if request_pty == 'force':
request_pty = True
Expand Down Expand Up @@ -5601,8 +5589,8 @@ def session_factory() -> SSHTunTapSession:
return cast(SSHForwarder, peer)

@async_context_manager
async def start_sftp_client(self, env: DefTuple[_Env] = (),
send_env: DefTuple[_SendEnv] = (),
async def start_sftp_client(self, env: DefTuple[Env] = (),
send_env: DefTuple[Optional[EnvList]] = (),
path_encoding: Optional[str] = 'utf-8',
path_errors = 'strict',
sftp_version = MIN_SFTP_VERSION) -> SFTPClient:
Expand Down Expand Up @@ -7781,8 +7769,8 @@ class SSHClientConnectionOptions(SSHConnectionOptions):
pkcs11_pin: Optional[str]
command: Optional[str]
subsystem: Optional[str]
env: _Env
send_env: _SendEnv
env: Env
send_env: Optional[EnvList]
request_pty: _RequestPTY
term_type: Optional[str]
term_size: TermSizeArg
Expand Down Expand Up @@ -7848,8 +7836,8 @@ def prepare(self, # type: ignore
pkcs11_provider: DefTuple[Optional[str]] = (),
pkcs11_pin: Optional[str] = None,
command: DefTuple[Optional[str]] = (),
subsystem: Optional[str] = None, env: DefTuple[_Env] = (),
send_env: DefTuple[_SendEnv] = (),
subsystem: Optional[str] = None, env: DefTuple[Env] = (),
send_env: DefTuple[Optional[EnvList]] = (),
request_pty: DefTuple[_RequestPTY] = (),
term_type: Optional[str] = None,
term_size: TermSizeArg = None,
Expand Down Expand Up @@ -8068,9 +8056,9 @@ def prepare(self, # type: ignore

self.subsystem = subsystem

self.env = cast(_Env, env if env != () else config.get('SetEnv'))
self.env = cast(Env, env if env != () else config.get('SetEnv'))

self.send_env = cast(_SendEnv, send_env if send_env != () else
self.send_env = cast(Optional[EnvList], send_env if send_env != () else
config.get('SendEnv'))

self.request_pty = cast(_RequestPTY, request_pty if request_pty != ()
Expand Down
57 changes: 55 additions & 2 deletions asyncssh/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
"""Miscellaneous utility classes and functions"""

import asyncio
import fnmatch
import functools
import ipaddress
import os
import re
import shlex
import socket
Expand All @@ -32,8 +34,8 @@
from random import SystemRandom
from types import TracebackType
from typing import Any, AsyncContextManager, Awaitable, Callable, Dict
from typing import Generator, Generic, IO, Mapping, Optional, Sequence
from typing import Tuple, Type, TypeVar, Union, cast, overload
from typing import Generator, Generic, IO, Iterator, Mapping, Optional
from typing import Sequence, Tuple, Type, TypeVar, Union, cast, overload
from typing_extensions import Literal, Protocol

from .constants import DEFAULT_LANG
Expand Down Expand Up @@ -109,6 +111,10 @@ async def wait_closed(self) -> None:
IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]]

EnvDict = Mapping[BytesOrStr, BytesOrStr]
EnvIter = Iterator[Tuple[BytesOrStr, BytesOrStr]]
EnvList = Sequence[BytesOrStr]
Env = Optional[Union[EnvDict, EnvIter, EnvList]]

# Define a version of randrange which is based on SystemRandom(), so that
# we get back numbers suitable for cryptographic use.
Expand All @@ -121,6 +127,53 @@ async def wait_closed(self) -> None:
'd': 24*60*60, 'w': 7*24*60*60}


def encode_env(env: Env) -> Iterator[Tuple[bytes, bytes]]:
"""Convert environemnt dict or list to bytes-based dictionary"""

try:
if isinstance(env, list):
for item in env:
if isinstance(item, str):
item = item.encode('utf-8')

yield item.split(b'=', 1)
else:
env = cast(EnvIter, env.items() if isinstance(env, dict) else env)

for key, value in env:
key_bytes = key.encode('utf-8') \
if isinstance(key, str) else key

value_bytes = value.encode('utf-8') \
if isinstance(value, str) else value

yield key_bytes, value_bytes
except (TypeError, ValueError) as exc:
raise ValueError('Invalid environment value: %s' % exc) from None


def lookup_env(patterns: EnvList) -> Iterator[Tuple[bytes, bytes]]:
"""Look up environemnt variables with wildcard matches"""

for pattern in patterns:
if isinstance(pattern, str):
pattern = pattern.encode('utf-8')

for key_bytes, value_bytes in os.environb.items():
if fnmatch.fnmatch(key_bytes, pattern):
yield key_bytes, value_bytes


def decode_env(env: Dict[bytes, bytes]) -> Iterator[Tuple[str, str]]:
"""Convert bytes-based environemnt dict to Unicode strings"""

for key, value in env.items():
try:
yield key.decode('utf-8'), value.decode('utf-8')
except UnicodeDecodeError:
pass


def hide_empty(value: object, prefix: str = ', ') -> str:
"""Return a string with optional prefix if value is non-empty"""

Expand Down
Loading

0 comments on commit d6a65a1

Please sign in to comment.