diff --git a/asyncssh/__init__.py b/asyncssh/__init__.py index 5b99444..a8328c6 100644 --- a/asyncssh/__init__.py +++ b/asyncssh/__init__.py @@ -44,6 +44,7 @@ from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions +from .connection import SSHAcceptHandler from .connection import create_connection, create_server, connect, listen from .connection import connect_reverse, listen_reverse, get_server_host_key from .connection import get_server_auth_methods, run_client, run_server diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 872d1dd..3aaa70c 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -233,6 +233,7 @@ async def create_server(self, session_factory: TCPListenerFactory, _VersionArg = DefTuple[BytesOrStr] +SSHAcceptHandler = Callable[[str, int], MaybeAwait[bool]] # SSH service names _USERAUTH_SERVICE = b'ssh-userauth' @@ -2886,10 +2887,10 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder: return SSHForwarder(cast(SSHForwarder, peer)) @async_context_manager - async def forward_local_port(self, listen_host: str, - listen_port: int, - dest_host: str, - dest_port: int) -> SSHListener: + async def forward_local_port( + self, listen_host: str, listen_port: int, + dest_host: str, dest_port: int, + accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener: """Set up local port forwarding This method is a coroutine which attempts to set up port @@ -2906,10 +2907,17 @@ async def forward_local_port(self, listen_host: str, The hostname or address to forward the connections to :param dest_port: The port number to forward the connections to + :param accept_handler: + A `callable` or coroutine which takes arguments of the + original host and port of the client and decides whether + or not to allow connection forwarding, returning `True` to + accept the connection and begin forwarding or `False` to + reject and close it. :type listen_host: `str` :type listen_port: `int` :type dest_host: `str` :type dest_port: `int` + :type accept_handler: `callable` or coroutine :returns: :class:`SSHListener` @@ -2923,6 +2931,21 @@ async def tunnel_connection( Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Forward a local connection over SSH""" + if accept_handler: + result = accept_handler(orig_host, orig_port) + + if inspect.isawaitable(result): + result = await cast(Awaitable[bool], result) + + if not result: + self.logger.info('Request for TCP forwarding from ' + '%s to %s denied by application', + (orig_host, orig_port), + (dest_host, dest_port)) + + raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, + 'Connection forwarding denied') + return (await self.create_connection(session_factory, dest_host, dest_port, orig_host, orig_port)) @@ -4695,9 +4718,9 @@ async def listen_reverse_ssh(self, host: str = '', **kwargs) # type: ignore @async_context_manager - async def forward_local_port_to_path(self, listen_host: str, - listen_port: int, - dest_path: str) -> SSHListener: + async def forward_local_port_to_path( + self, listen_host: str, listen_port: int, dest_path: str, + accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener: """Set up local TCP port forwarding to a remote UNIX domain socket This method is a coroutine which attempts to set up port @@ -4712,9 +4735,16 @@ async def forward_local_port_to_path(self, listen_host: str, The port number on the local host to listen on :param dest_path: The path on the remote host to forward the connections to + :param accept_handler: + A `callable` or coroutine which takes arguments of the + original host and port of the client and decides whether + or not to allow connection forwarding, returning `True` to + accept the connection and begin forwarding or `False` to + reject and close it. :type listen_host: `str` :type listen_port: `int` :type dest_path: `str` + :type accept_handler: `callable` or coroutine :returns: :class:`SSHListener` @@ -4724,10 +4754,24 @@ async def forward_local_port_to_path(self, listen_host: str, async def tunnel_connection( session_factory: SSHUNIXSessionFactory[bytes], - _orig_host: str, _orig_port: int) -> \ + orig_host: str, orig_port: int) -> \ Tuple[SSHUNIXChannel[bytes], SSHUNIXSession[bytes]]: """Forward a local connection over SSH""" + if accept_handler: + result = accept_handler(orig_host, orig_port) + + if inspect.isawaitable(result): + result = await cast(Awaitable[bool], result) + + if not result: + self.logger.info('Request for TCP forwarding from ' + '%s to %s denied by application', + (orig_host, orig_port), dest_path) + + raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, + 'Connection forwarding denied') + return (await self.create_unix_connection(session_factory, dest_path)) @@ -5737,6 +5781,10 @@ async def _finish_port_forward(self, listen_host: str, if listener is True: listener = await self.forward_local_port( listen_host, listen_port, listen_host, listen_port) + elif callable(listener): + listener = await self.forward_local_port( + listen_host, listen_port, + listen_host, listen_port, listener) except OSError: self.logger.debug1('Failed to create TCP listener') self._report_global_response(False) diff --git a/asyncssh/server.py b/asyncssh/server.py index 2c17647..0bab344 100644 --- a/asyncssh/server.py +++ b/asyncssh/server.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: # pylint: disable=cyclic-import - from .connection import SSHServerConnection + from .connection import SSHServerConnection, SSHAcceptHandler from .channel import SSHServerChannel, SSHTCPChannel, SSHUNIXChannel from .session import SSHServerSession, SSHTCPSession, SSHUNIXSession @@ -45,7 +45,7 @@ _NewUNIXSession = Union[bool, 'SSHUNIXSession', SSHSocketSessionFactory, Tuple['SSHUNIXChannel', 'SSHUNIXSession'], Tuple['SSHUNIXChannel', SSHSocketSessionFactory]] -_NewListener = Union[bool, SSHListener] +_NewListener = Union[bool, 'SSHAcceptHandler', SSHListener] class SSHServer: diff --git a/tests/test_forward.py b/tests/test_forward.py index 47afaa4..7c2895e 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -183,6 +183,18 @@ async def server_requested(self, listen_host, listen_port): return listen_host != 'fail' +class _TCPAcceptHandlerServer(Server): + """Server for testing forwarding accept handler""" + + async def server_requested(self, listen_host, listen_port): + """Handle a request to create a new socket listener""" + + def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return True + + return accept_handler + + class _UNIXConnectionServer(Server): """Server for testing direct and forwarded UNIX domain connections""" @@ -594,6 +606,39 @@ async def test_forward_local_port(self): await self._check_local_connection(listener.get_port(), delay=0.1) + @asynctest + async def test_forward_local_port_accept_handler(self): + """Test forwarding of a local port with an accept handler""" + + def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return True + + async with self.connect() as conn: + async with conn.forward_local_port('', 0, '', 7, + accept_handler) as listener: + await self._check_local_connection(listener.get_port(), + delay=0.1) + + @asynctest + async def test_forward_local_port_accept_handler_denial(self): + """Test forwarding of a local port with an accept handler denial""" + + async def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return False + + async with self.connect() as conn: + async with conn.forward_local_port('', 0, '', 7, + accept_handler) as listener: + listen_port = listener.get_port() + + reader, writer = await asyncio.open_connection('127.0.0.1', + listen_port) + + self.assertEqual((await reader.read()), b'') + + writer.close() + await maybe_wait_closed(writer) + @unittest.skipIf(sys.platform == 'win32', 'skip UNIX domain socket tests on Windows') @asynctest @@ -855,6 +900,33 @@ async def test_listener_close_on_conn_close(self): await listener.wait_closed() +class _TestTCPForwardingAcceptHandler(_CheckForwarding): + """Unit tests for TCP forwarding with accept handler""" + + @classmethod + async def start_server(cls): + """Start an SSH server which supports TCP connection forwarding""" + + return await cls.create_server( + _TCPAcceptHandlerServer, authorized_client_keys='authorized_keys') + + @asynctest + async def test_forward_remote_port_accept_handler(self): + """Test forwarding of a remote port with accept handler""" + + server = await asyncio.start_server(echo, None, 0, + family=socket.AF_INET) + server_port = server.sockets[0].getsockname()[1] + + async with self.connect() as conn: + async with conn.forward_remote_port( + '', 0, '127.0.0.1', server_port) as listener: + await self._check_local_connection(listener.get_port()) + + server.close() + await server.wait_closed() + + class _TestAsyncTCPForwarding(_TestTCPForwarding): """Unit tests for AsyncSSH TCP connection forwarding with async return""" @@ -999,6 +1071,39 @@ async def test_forward_local_path(self): os.remove('local') + @asynctest + async def test_forward_local_port_to_path_accept_handler(self): + """Test forwarding of port to UNIX path with accept handler""" + + def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return True + + async with self.connect() as conn: + async with conn.forward_local_port_to_path( + '', 0, '/echo', accept_handler) as listener: + await self._check_local_connection(listener.get_port(), + delay=0.1) + + @asynctest + async def test_forward_local_port_to_path_accept_handler_denial(self): + """Test forwarding of port to UNIX path with accept handler denial""" + + async def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return False + + async with self.connect() as conn: + async with conn.forward_local_port_to_path( + '', 0, '/echo', accept_handler) as listener: + listen_port = listener.get_port() + + reader, writer = await asyncio.open_connection('127.0.0.1', + listen_port) + + self.assertEqual((await reader.read()), b'') + + writer.close() + await maybe_wait_closed(writer) + @asynctest async def test_forward_local_port_to_path(self): """Test forwarding of a local port to a remote UNIX domain socket"""