-
Notifications
You must be signed in to change notification settings - Fork 155
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use an executor to prevent GSSAPI calls from blocking the event loop
Some operations such as GSSAPI calls can sometimes block the event loop if not run in an executor. However, doing that requires packet handlers to be asynchronous. This commit adds support for async packet handlers for key exchange and auth, and changes the GSSAPI handlers to run the step() call in an executor.
- Loading branch information
Showing
10 changed files
with
210 additions
and
155 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) 2013-2022 by Ron Frederick <[email protected]> and others. | ||
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others. | ||
# | ||
# This program and the accompanying materials are made available under | ||
# the terms of the Eclipse Public License v2.0 which accompanies this | ||
|
@@ -27,6 +27,7 @@ | |
from .gss import GSSBase, GSSError | ||
from .logging import SSHLogger | ||
from .misc import ProtocolError, PasswordChangeRequired, get_symbol_names | ||
from .misc import run_in_executor | ||
from .packet import Boolean, String, UInt32, SSHPacket, SSHPacketHandler | ||
from .public_key import SigningKey | ||
from .saslprep import saslprep, SASLPrepError | ||
|
@@ -199,8 +200,8 @@ def _finish(self) -> None: | |
else: | ||
self.send_packet(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE) | ||
|
||
def _process_response(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_response(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS response from the server""" | ||
|
||
mech = packet.get_string() | ||
|
@@ -212,7 +213,7 @@ def _process_response(self, _pkttype: int, _pktid: int, | |
raise ProtocolError('Mechanism mismatch') | ||
|
||
try: | ||
token = self._gss.step() | ||
token = await run_in_executor(self._gss.step) | ||
assert token is not None | ||
|
||
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) | ||
|
@@ -225,8 +226,8 @@ def _process_response(self, _pkttype: int, _pktid: int, | |
|
||
self._conn.try_next_auth() | ||
|
||
def _process_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS token from the server""" | ||
|
||
token: Optional[bytes] = packet.get_string() | ||
|
@@ -235,7 +236,7 @@ def _process_token(self, _pkttype: int, _pktid: int, | |
assert self._gss is not None | ||
|
||
try: | ||
token = self._gss.step(token) | ||
token = await run_in_executor(self._gss.step, token) | ||
|
||
if token: | ||
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) | ||
|
@@ -261,8 +262,8 @@ def _process_error(self, _pkttype: int, _pktid: int, | |
self.logger.debug1('GSS error from server: %s', msg) | ||
self._got_error = True | ||
|
||
def _process_error_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_error_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS error token from the server""" | ||
|
||
token = packet.get_string() | ||
|
@@ -271,7 +272,7 @@ def _process_error_token(self, _pkttype: int, _pktid: int, | |
assert self._gss is not None | ||
|
||
try: | ||
self._gss.step(token) | ||
await run_in_executor(self._gss.step, token) | ||
except GSSError as exc: | ||
if not self._got_error: # pragma: no cover | ||
self.logger.debug1('GSS error from server: %s', str(exc)) | ||
|
@@ -649,15 +650,15 @@ async def _finish(self) -> None: | |
else: | ||
self.send_failure() | ||
|
||
def _process_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS token from the client""" | ||
|
||
token: Optional[bytes] = packet.get_string() | ||
packet.check_end() | ||
|
||
try: | ||
token = self._gss.step(token) | ||
token = await run_in_executor(self._gss.step, token) | ||
|
||
if token: | ||
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) | ||
|
@@ -682,15 +683,15 @@ def _process_exchange_complete(self, _pkttype: int, _pktid: int, | |
else: | ||
self.send_failure() | ||
|
||
def _process_error_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_error_token(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS error token from the client""" | ||
|
||
token = packet.get_string() | ||
packet.check_end() | ||
|
||
try: | ||
self._gss.step(token) | ||
await run_in_executor(self._gss.step, token) | ||
except GSSError as exc: | ||
self.logger.debug1('GSS error from client: %s', str(exc)) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) 2013-2022 by Ron Frederick <[email protected]> and others. | ||
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others. | ||
# | ||
# This program and the accompanying materials are made available under | ||
# the terms of the Eclipse Public License v2.0 which accompanies this | ||
|
@@ -58,7 +58,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType): | |
self._hash_alg = hash_alg | ||
|
||
|
||
def start(self) -> None: | ||
async def start(self) -> None: | ||
"""Start key exchange""" | ||
|
||
raise NotImplementedError | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) 2013-2022 by Ron Frederick <[email protected]> and others. | ||
# Copyright (c) 2013-2024 by Ron Frederick <[email protected]> and others. | ||
# | ||
# This program and the accompanying materials are made available under | ||
# the terms of the Eclipse Public License v2.0 which accompanies this | ||
|
@@ -33,7 +33,7 @@ | |
from .gss import GSSError | ||
from .kex import Kex, register_kex_alg, register_gss_kex_alg | ||
from .misc import HashType, KeyExchangeFailed, ProtocolError | ||
from .misc import get_symbol_names | ||
from .misc import get_symbol_names, run_in_executor | ||
from .packet import Boolean, MPInt, String, UInt32, SSHPacket | ||
from .public_key import SigningKey, VerifyingKey | ||
|
||
|
@@ -274,7 +274,7 @@ def _process_reply(self, _pkttype: int, _pktid: int, | |
host_key = client_conn.validate_server_host_key(host_key_data) | ||
self._verify_reply(host_key, host_key_data, sig) | ||
|
||
def start(self) -> None: | ||
async def start(self) -> None: | ||
"""Start DH key exchange""" | ||
|
||
if self._conn.is_client(): | ||
|
@@ -384,7 +384,7 @@ def _process_group(self, _pkttype: int, _pktid: int, | |
self._gex_data += MPInt(p) + MPInt(g) | ||
self._perform_init() | ||
|
||
def start(self) -> None: | ||
async def start(self) -> None: | ||
"""Start DH group exchange""" | ||
|
||
if self._conn.is_client(): | ||
|
@@ -455,7 +455,7 @@ def _compute_server_shared(self) -> bytes: | |
except ValueError: | ||
raise ProtocolError('Invalid ECDH client public key') from None | ||
|
||
def start(self) -> None: | ||
async def start(self) -> None: | ||
"""Start ECDH key exchange""" | ||
|
||
if self._conn.is_client(): | ||
|
@@ -567,11 +567,11 @@ def _send_continue(self) -> None: | |
|
||
self.send_packet(MSG_KEXGSS_CONTINUE, String(self._token)) | ||
|
||
def _process_token(self, token: Optional[bytes] = None) -> None: | ||
async def _process_token(self, token: Optional[bytes] = None) -> None: | ||
"""Process a GSS token""" | ||
|
||
try: | ||
self._token = self._gss.step(token) | ||
self._token = await run_in_executor(self._gss.step, token) | ||
except GSSError as exc: | ||
if self._conn.is_server(): | ||
self.send_packet(MSG_KEXGSS_ERROR, UInt32(exc.maj_code), | ||
|
@@ -583,8 +583,8 @@ def _process_token(self, token: Optional[bytes] = None) -> None: | |
|
||
raise KeyExchangeFailed(str(exc)) from None | ||
|
||
def _process_init(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_gss_init(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS init message""" | ||
|
||
if self._conn.is_client(): | ||
|
@@ -603,7 +603,7 @@ def _process_init(self, _pkttype: int, _pktid: int, | |
else: | ||
self._host_key_data = b'' | ||
|
||
self._process_token(token) | ||
await self._process_token(token) | ||
|
||
if self._gss.complete: | ||
self._check_secure() | ||
|
@@ -612,8 +612,8 @@ def _process_init(self, _pkttype: int, _pktid: int, | |
else: | ||
self._send_continue() | ||
|
||
def _process_continue(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_continue(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS continue message""" | ||
|
||
token = packet.get_string() | ||
|
@@ -622,16 +622,16 @@ def _process_continue(self, _pkttype: int, _pktid: int, | |
if self._conn.is_client() and self._gss.complete: | ||
raise ProtocolError('Unexpected kexgss continue msg') | ||
|
||
self._process_token(token) | ||
await self._process_token(token) | ||
|
||
if self._conn.is_server() and self._gss.complete: | ||
self._check_secure() | ||
self._perform_reply(self._gss, self._host_key_data) | ||
else: | ||
self._send_continue() | ||
|
||
def _process_complete(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
async def _process_complete(self, _pkttype: int, _pktid: int, | ||
packet: SSHPacket) -> None: | ||
"""Process a GSS complete message""" | ||
|
||
if self._conn.is_server(): | ||
|
@@ -647,7 +647,7 @@ def _process_complete(self, _pkttype: int, _pktid: int, | |
if self._gss.complete: | ||
raise ProtocolError('Non-empty token after complete') | ||
|
||
self._process_token(token) | ||
await self._process_token(token) | ||
|
||
if self._token: | ||
raise ProtocolError('Non-empty token after complete') | ||
|
@@ -682,12 +682,12 @@ def _process_error(self, _pkttype: int, _pktid: int, | |
self._conn.logger.debug1('GSS error: %s', | ||
msg.decode('utf-8', errors='ignore')) | ||
|
||
def start(self) -> None: | ||
async def start(self) -> None: | ||
"""Start GSS key exchange""" | ||
|
||
if self._conn.is_client(): | ||
self._process_token() | ||
super().start() | ||
await self._process_token() | ||
await super().start() | ||
|
||
|
||
class _KexGSS(_KexGSSBase, _KexDH): | ||
|
@@ -696,7 +696,7 @@ class _KexGSS(_KexGSSBase, _KexDH): | |
_handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') | ||
|
||
_packet_handlers = { | ||
MSG_KEXGSS_INIT: _KexGSSBase._process_init, | ||
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, | ||
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, | ||
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, | ||
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, | ||
|
@@ -713,7 +713,7 @@ class _KexGSSGex(_KexGSSBase, _KexDHGex): | |
_group_type = MSG_KEXGSS_GROUP | ||
|
||
_packet_handlers = { | ||
MSG_KEXGSS_INIT: _KexGSSBase._process_init, | ||
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, | ||
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, | ||
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, | ||
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, | ||
|
@@ -729,7 +729,7 @@ class _KexGSSECDH(_KexGSSBase, _KexECDH): | |
_handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') | ||
|
||
_packet_handlers = { | ||
MSG_KEXGSS_INIT: _KexGSSBase._process_init, | ||
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, | ||
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, | ||
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, | ||
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, | ||
|
Oops, something went wrong.