diff --git a/asyncssh/channel.py b/asyncssh/channel.py index a660cab..0323034 100644 --- a/asyncssh/channel.py +++ b/asyncssh/channel.py @@ -46,11 +46,11 @@ from .logging import SSHLogger from .misc import ChannelOpenError, MaybeAwait, ProtocolError +from .misc import TermModes, TermSize, TermSizeArg from .misc import get_symbol_names, map_handler_name from .packet import Boolean, Byte, String, UInt32, SSHPacket, SSHPacketHandler -from .session import TermModes, TermSize, TermSizeArg from .session import SSHSession, SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .session import SSHSessionFactory, SSHClientSessionFactory diff --git a/asyncssh/connection.py b/asyncssh/connection.py index d258f72..0e4a9e1 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -112,6 +112,7 @@ from .misc import KeyExchangeFailed, IllegalUserName, MACError from .misc import PasswordChangeRequired, PermissionDenied, ProtocolError from .misc import ProtocolNotSupported, ServiceNotAvailable +from .misc import TermModesArg, TermSizeArg from .misc import async_context_manager, construct_disc_error from .misc import get_symbol_names, ip_address, map_handler_name from .misc import parse_byte_count, parse_time_interval @@ -146,8 +147,7 @@ from .server import SSHServer -from .session import DataType, TermModesArg, TermSizeArg -from .session import SSHClientSession, SSHServerSession +from .session import DataType, SSHClientSession, SSHServerSession from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .session import SSHClientSessionFactory, SSHTCPSessionFactory from .session import SSHUNIXSessionFactory, SSHTunTapSessionFactory diff --git a/asyncssh/misc.py b/asyncssh/misc.py index 22c75d9..b94265f 100644 --- a/asyncssh/misc.py +++ b/asyncssh/misc.py @@ -24,6 +24,7 @@ import ipaddress import re import socket +import sys from pathlib import Path, PurePath from random import SystemRandom @@ -41,6 +42,16 @@ from .constants import DISC_PROTOCOL_ERROR, DISC_PROTOCOL_VERSION_NOT_SUPPORTED from .constants import DISC_SERVICE_NOT_AVAILABLE +if sys.platform != 'win32': # pragma: no branch + import fcntl + import struct + import termios + +TermModes = Mapping[int, int] +TermModesArg = Optional[TermModes] +TermSize = Tuple[int, int, int, int] +TermSizeArg = Union[None, Tuple[int, int], TermSize] + class _Hash(Protocol): """Protocol for hashing data""" @@ -331,6 +342,14 @@ async def maybe_wait_closed(writer: '_SupportsWaitClosed') -> None: pass +def set_terminal_size(tty: IO, width: int, height: int, + pixwidth: int, pixheight: int) -> None: + """Set the terminal size of a TTY""" + + fcntl.ioctl(tty, termios.TIOCSWINSZ, + struct.pack('hhhh', height, width, pixwidth, pixheight)) + + class Options: """Container for configuration options""" @@ -764,6 +783,12 @@ def __init__(self, width: int, height: int, pixwidth: int, pixheight: int): self.pixwidth = pixwidth self.pixheight = pixheight + @property + def term_size(self) -> TermSize: + """Return terminal size as a tuple of 4 integers""" + + return self.width, self.height, self.pixwidth, self.pixheight + _disc_error_map = { DISC_PROTOCOL_ERROR: ProtocolError, diff --git a/asyncssh/process.py b/asyncssh/process.py index 7c7b5a9..bc3496d 100644 --- a/asyncssh/process.py +++ b/asyncssh/process.py @@ -41,17 +41,16 @@ from .logging import SSHLogger -from .misc import BytesOrStr, Error, MaybeAwait -from .misc import ProtocolError, Record, open_file +from .misc import BytesOrStr, Error, MaybeAwait, TermModes, TermSize +from .misc import ProtocolError, Record, open_file, set_terminal_size from .misc import BreakReceived, SignalReceived, TerminalSizeChanged -from .session import DataType, TermModes, TermSize +from .session import DataType from .stream import SSHReader, SSHWriter, SSHStreamSession from .stream import SSHClientStreamSession, SSHServerStreamSession from .stream import SFTPServerFactory - _AnyStrContra = TypeVar('_AnyStrContra', bytes, str, contravariant=True) _File = Union[IO[bytes], '_AsyncFileProtocol[bytes]'] @@ -406,6 +405,7 @@ def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType, self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype self._transport: Optional[asyncio.WriteTransport] = None + self._tty: Optional[IO] = None self._close_event = asyncio.Event() def connection_made(self, transport: asyncio.BaseTransport) -> None: @@ -413,6 +413,12 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self._transport = cast(asyncio.WriteTransport, transport) + pipe = transport.get_extra_info('pipe') + + if isinstance(self._process, SSHServerProcess) and pipe.isatty(): + self._tty = pipe + set_terminal_size(pipe, *self._process.term_size) + def connection_lost(self, exc: Optional[Exception]) -> None: """Handle closing of the pipe""" @@ -434,6 +440,12 @@ def write(self, data: AnyStr) -> None: assert self._transport is not None self._transport.write(self.encode(data)) + def write_exception(self, exc: Exception) -> None: + """Write terminal size changes to the pipe if it is a TTY""" + + if isinstance(exc, TerminalSizeChanged) and self._tty: + set_terminal_size(self._tty, *exc.term_size) + def write_eof(self) -> None: """Write EOF to the pipe""" diff --git a/asyncssh/session.py b/asyncssh/session.py index 5cd61e1..329975c 100644 --- a/asyncssh/session.py +++ b/asyncssh/session.py @@ -21,7 +21,7 @@ """SSH session handlers""" from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic -from typing import Mapping, Optional, Tuple, Union +from typing import Mapping, Optional, Tuple if TYPE_CHECKING: @@ -31,11 +31,6 @@ DataType = Optional[int] -TermModes = Mapping[int, int] -TermModesArg = Optional[TermModes] -TermSize = Tuple[int, int, int, int] -TermSizeArg = Union[None, Tuple[int, int], TermSize] - class SSHSession(Generic[AnyStr]): """SSH session handler""" diff --git a/tests/test_process.py b/tests/test_process.py index 4026c0e..eccdf6b 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -34,12 +34,18 @@ from .server import ServerTestCase from .util import asynctest, echo +if sys.platform != 'win32': # pragma: no branch + import fcntl + import struct + import termios + try: import aiofiles _aiofiles_available = True except ImportError: # pragma: no cover _aiofiles_available = False + async def _handle_client(process): """Handle a new client request""" @@ -100,6 +106,23 @@ async def _handle_client(process): except asyncssh.TerminalSizeChanged as exc: process.exit_with_signal('ABRT', False, '%sx%s' % (exc.width, exc.height)) + elif action == 'term_size_tty': + master, slave = os.openpty() + await process.redirect_stdin(master) + process.stdout.write(b'\n') + + await process.stdin.readline() + size = fcntl.ioctl(slave, termios.TIOCGWINSZ, 8*b'\0') + height, width, _, _ = struct.unpack('hhhh', size) + process.stdout.write(('%sx%s' % (width, height)).encode()) + os.close(slave) + elif action == 'term_size_nontty': + rpipe, wpipe = os.pipe() + await process.redirect_stdin(wpipe) + process.stdout.write(b'\n') + + await process.stdin.readline() + os.close(rpipe) elif action == 'timeout': process.channel.set_encoding('utf-8') process.stdout.write('Sleeping') @@ -648,6 +671,36 @@ async def test_forward_terminal_size(self): self.assertEqual(result.exit_signal[2], '80x24') + @unittest.skipIf(sys.platform == 'win32', + 'skip fcntl/termios test on Windows') + @asynctest + async def test_forward_terminal_size_tty(self): + """Test forwarding a terminal size change to a remote tty""" + + async with self.connect() as conn: + process = await conn.create_process('term_size_tty', + term_type='ansi') + await process.stdout.readline() + process.change_terminal_size(80, 24) + process.stdin.write_eof() + result = await process.wait() + + self.assertEqual(result.stdout, '80x24') + + @asynctest + async def test_forward_terminal_size_nontty(self): + """Test forwarding a terminal size change to a remote non-tty""" + + async with self.connect() as conn: + process = await conn.create_process('term_size_nontty', + term_type='ansi') + await process.stdout.readline() + process.change_terminal_size(80, 24) + process.stdin.write_eof() + result = await process.wait() + + self.assertEqual(result.stdout, '') + @asynctest async def test_forward_break(self): """Test forwarding a break"""