Skip to content

Commit

Permalink
Add support for an "accept handler" in connection forwarding
Browse files Browse the repository at this point in the history
This commit adds support for a new accept_handler argument in the
forward_local_port and forward_local_port_to_path methods in
SSHClientConnection and the ability to return an accept handler
in the server_requested method in SSHServer. This method receives
the original host & port of the incoming forwarded connection and
can return a bool to determine whether forwarding is allowed or not.
Thanks go to GitHub user zgxkbtl for suggesting this feature!
  • Loading branch information
ronf committed Oct 1, 2023
1 parent 777d328 commit 70f65eb
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 10 deletions.
1 change: 1 addition & 0 deletions asyncssh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection
from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions
from .connection import SSHAcceptHandler
from .connection import create_connection, create_server, connect, listen
from .connection import connect_reverse, listen_reverse, get_server_host_key
from .connection import get_server_auth_methods, run_client, run_server
Expand Down
64 changes: 56 additions & 8 deletions asyncssh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ async def create_server(self, session_factory: TCPListenerFactory,

_VersionArg = DefTuple[BytesOrStr]

SSHAcceptHandler = Callable[[str, int], MaybeAwait[bool]]

# SSH service names
_USERAUTH_SERVICE = b'ssh-userauth'
Expand Down Expand Up @@ -2886,10 +2887,10 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder:
return SSHForwarder(cast(SSHForwarder, peer))

@async_context_manager
async def forward_local_port(self, listen_host: str,
listen_port: int,
dest_host: str,
dest_port: int) -> SSHListener:
async def forward_local_port(
self, listen_host: str, listen_port: int,
dest_host: str, dest_port: int,
accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener:
"""Set up local port forwarding
This method is a coroutine which attempts to set up port
Expand All @@ -2906,10 +2907,17 @@ async def forward_local_port(self, listen_host: str,
The hostname or address to forward the connections to
:param dest_port:
The port number to forward the connections to
:param accept_handler:
A `callable` or coroutine which takes arguments of the
original host and port of the client and decides whether
or not to allow connection forwarding, returning `True` to
accept the connection and begin forwarding or `False` to
reject and close it.
:type listen_host: `str`
:type listen_port: `int`
:type dest_host: `str`
:type dest_port: `int`
:type accept_handler: `callable` or coroutine
:returns: :class:`SSHListener`
Expand All @@ -2923,6 +2931,21 @@ async def tunnel_connection(
Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]:
"""Forward a local connection over SSH"""

if accept_handler:
result = accept_handler(orig_host, orig_port)

if inspect.isawaitable(result):
result = await cast(Awaitable[bool], result)

if not result:
self.logger.info('Request for TCP forwarding from '
'%s to %s denied by application',
(orig_host, orig_port),
(dest_host, dest_port))

raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED,
'Connection forwarding denied')

return (await self.create_connection(session_factory,
dest_host, dest_port,
orig_host, orig_port))
Expand Down Expand Up @@ -4695,9 +4718,9 @@ async def listen_reverse_ssh(self, host: str = '',
**kwargs) # type: ignore

@async_context_manager
async def forward_local_port_to_path(self, listen_host: str,
listen_port: int,
dest_path: str) -> SSHListener:
async def forward_local_port_to_path(
self, listen_host: str, listen_port: int, dest_path: str,
accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener:
"""Set up local TCP port forwarding to a remote UNIX domain socket
This method is a coroutine which attempts to set up port
Expand All @@ -4712,9 +4735,16 @@ async def forward_local_port_to_path(self, listen_host: str,
The port number on the local host to listen on
:param dest_path:
The path on the remote host to forward the connections to
:param accept_handler:
A `callable` or coroutine which takes arguments of the
original host and port of the client and decides whether
or not to allow connection forwarding, returning `True` to
accept the connection and begin forwarding or `False` to
reject and close it.
:type listen_host: `str`
:type listen_port: `int`
:type dest_path: `str`
:type accept_handler: `callable` or coroutine
:returns: :class:`SSHListener`
Expand All @@ -4724,10 +4754,24 @@ async def forward_local_port_to_path(self, listen_host: str,

async def tunnel_connection(
session_factory: SSHUNIXSessionFactory[bytes],
_orig_host: str, _orig_port: int) -> \
orig_host: str, orig_port: int) -> \
Tuple[SSHUNIXChannel[bytes], SSHUNIXSession[bytes]]:
"""Forward a local connection over SSH"""

if accept_handler:
result = accept_handler(orig_host, orig_port)

if inspect.isawaitable(result):
result = await cast(Awaitable[bool], result)

if not result:
self.logger.info('Request for TCP forwarding from '
'%s to %s denied by application',
(orig_host, orig_port), dest_path)

raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED,
'Connection forwarding denied')

return (await self.create_unix_connection(session_factory,
dest_path))

Expand Down Expand Up @@ -5737,6 +5781,10 @@ async def _finish_port_forward(self, listen_host: str,
if listener is True:
listener = await self.forward_local_port(
listen_host, listen_port, listen_host, listen_port)
elif callable(listener):
listener = await self.forward_local_port(
listen_host, listen_port,
listen_host, listen_port, listener)
except OSError:
self.logger.debug1('Failed to create TCP listener')
self._report_global_response(False)
Expand Down
4 changes: 2 additions & 2 deletions asyncssh/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

if TYPE_CHECKING:
# pylint: disable=cyclic-import
from .connection import SSHServerConnection
from .connection import SSHServerConnection, SSHAcceptHandler
from .channel import SSHServerChannel, SSHTCPChannel, SSHUNIXChannel
from .session import SSHServerSession, SSHTCPSession, SSHUNIXSession

Expand All @@ -45,7 +45,7 @@
_NewUNIXSession = Union[bool, 'SSHUNIXSession', SSHSocketSessionFactory,
Tuple['SSHUNIXChannel', 'SSHUNIXSession'],
Tuple['SSHUNIXChannel', SSHSocketSessionFactory]]
_NewListener = Union[bool, SSHListener]
_NewListener = Union[bool, 'SSHAcceptHandler', SSHListener]


class SSHServer:
Expand Down
105 changes: 105 additions & 0 deletions tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ async def server_requested(self, listen_host, listen_port):
return listen_host != 'fail'


class _TCPAcceptHandlerServer(Server):
"""Server for testing forwarding accept handler"""

async def server_requested(self, listen_host, listen_port):
"""Handle a request to create a new socket listener"""

def accept_handler(_orig_host: str, _orig_port: int) -> bool:
return True

return accept_handler


class _UNIXConnectionServer(Server):
"""Server for testing direct and forwarded UNIX domain connections"""

Expand Down Expand Up @@ -594,6 +606,39 @@ async def test_forward_local_port(self):
await self._check_local_connection(listener.get_port(),
delay=0.1)

@asynctest
async def test_forward_local_port_accept_handler(self):
"""Test forwarding of a local port with an accept handler"""

def accept_handler(_orig_host: str, _orig_port: int) -> bool:
return True

async with self.connect() as conn:
async with conn.forward_local_port('', 0, '', 7,
accept_handler) as listener:
await self._check_local_connection(listener.get_port(),
delay=0.1)

@asynctest
async def test_forward_local_port_accept_handler_denial(self):
"""Test forwarding of a local port with an accept handler denial"""

async def accept_handler(_orig_host: str, _orig_port: int) -> bool:
return False

async with self.connect() as conn:
async with conn.forward_local_port('', 0, '', 7,
accept_handler) as listener:
listen_port = listener.get_port()

reader, writer = await asyncio.open_connection('127.0.0.1',
listen_port)

self.assertEqual((await reader.read()), b'')

writer.close()
await maybe_wait_closed(writer)

@unittest.skipIf(sys.platform == 'win32',
'skip UNIX domain socket tests on Windows')
@asynctest
Expand Down Expand Up @@ -855,6 +900,33 @@ async def test_listener_close_on_conn_close(self):
await listener.wait_closed()


class _TestTCPForwardingAcceptHandler(_CheckForwarding):
"""Unit tests for TCP forwarding with accept handler"""

@classmethod
async def start_server(cls):
"""Start an SSH server which supports TCP connection forwarding"""

return await cls.create_server(
_TCPAcceptHandlerServer, authorized_client_keys='authorized_keys')

@asynctest
async def test_forward_remote_port_accept_handler(self):
"""Test forwarding of a remote port with accept handler"""

server = await asyncio.start_server(echo, None, 0,
family=socket.AF_INET)
server_port = server.sockets[0].getsockname()[1]

async with self.connect() as conn:
async with conn.forward_remote_port(
'', 0, '127.0.0.1', server_port) as listener:
await self._check_local_connection(listener.get_port())

server.close()
await server.wait_closed()


class _TestAsyncTCPForwarding(_TestTCPForwarding):
"""Unit tests for AsyncSSH TCP connection forwarding with async return"""

Expand Down Expand Up @@ -999,6 +1071,39 @@ async def test_forward_local_path(self):

os.remove('local')

@asynctest
async def test_forward_local_port_to_path_accept_handler(self):
"""Test forwarding of port to UNIX path with accept handler"""

def accept_handler(_orig_host: str, _orig_port: int) -> bool:
return True

async with self.connect() as conn:
async with conn.forward_local_port_to_path(
'', 0, '/echo', accept_handler) as listener:
await self._check_local_connection(listener.get_port(),
delay=0.1)

@asynctest
async def test_forward_local_port_to_path_accept_handler_denial(self):
"""Test forwarding of port to UNIX path with accept handler denial"""

async def accept_handler(_orig_host: str, _orig_port: int) -> bool:
return False

async with self.connect() as conn:
async with conn.forward_local_port_to_path(
'', 0, '/echo', accept_handler) as listener:
listen_port = listener.get_port()

reader, writer = await asyncio.open_connection('127.0.0.1',
listen_port)

self.assertEqual((await reader.read()), b'')

writer.close()
await maybe_wait_closed(writer)

@asynctest
async def test_forward_local_port_to_path(self):
"""Test forwarding of a local port to a remote UNIX domain socket"""
Expand Down

0 comments on commit 70f65eb

Please sign in to comment.