From 9a3d13ff902fcbc77df6ae2038034974f156ae6d Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Fri, 29 Nov 2024 11:03:27 -0800 Subject: [PATCH] Update hostname canonicalization to better match OpenSSH This commit changes the initial hostname check in host canonicalization to only look for literal IP addresses, rather than attempting to resolve the host as if it were already fully qualified. This is more consistent with OpenSSH and also should avoid a potentially slow additional name lookup on every connection. --- asyncssh/connection.py | 47 ++++++++++++++++++++-------------------- tests/test_connection.py | 18 ++++++++++----- tests/util.py | 3 --- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/asyncssh/connection.py b/asyncssh/connection.py index e59fff6..3c666c0 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -25,6 +25,7 @@ import getpass import inspect import io +import ipaddress import os import shlex import socket @@ -274,18 +275,6 @@ async def create_server(self, session_factory: TCPListenerFactory, _DEFAULT_MAX_LINE_LENGTH = 1024 # 1024 characters -async def _resolve_host(host, loop: asyncio.AbstractEventLoop) -> Optional[str]: - """Attempt to resolve a hostname, returning a canonical name""" - - try: - addrinfo = await loop.getaddrinfo(host + '.', 0, - flags=socket.AI_CANONNAME) - except socket.gaierror: - return None - else: - return addrinfo[0][3] - - async def _canonicalize_host(loop: asyncio.AbstractEventLoop, options: 'SSHConnectionOptions') -> Optional[str]: """Canonicalize a host name""" @@ -293,25 +282,35 @@ async def _canonicalize_host(loop: asyncio.AbstractEventLoop, host = options.host if not options.canonicalize_hostname or not options.canonical_domains or \ - host.count('.') > options.canonicalize_max_dots or \ - (await _resolve_host(host, loop)): + host.count('.') > options.canonicalize_max_dots: + return None + + try: + ipaddress.ip_address(host) + except ValueError: + pass + else: return None for domain in options.canonical_domains: canon_host = f'{host}.{domain}' - cname = await _resolve_host(canon_host, loop) - if cname is not None: - if cname: - for patterns in options.canonicalize_permitted_cnames: - host_pat, cname_pat = map(WildcardPatternList, patterns) + try: + addrinfo = await loop.getaddrinfo( + canon_host, 0, flags=socket.AI_CANONNAME) + except socket.gaierror: + continue + + cname = addrinfo[0][3] + + if cname: + for patterns in options.canonicalize_permitted_cnames: + host_pat, cname_pat = map(WildcardPatternList, patterns) - if host_pat.matches(canon_host) and \ - cname_pat.matches(cname): - canon_host = cname - break + if host_pat.matches(canon_host) and cname_pat.matches(cname): + return cname - return canon_host + return canon_host if not options.canonicalize_fallback_local: raise OSError(f'Unable to canonicalize hostname "{host}"') diff --git a/tests/test_connection.py b/tests/test_connection.py index 7fea9ef..d1a0826 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2688,18 +2688,27 @@ async def start_server(cls): async def test_canonicalize(self): """Test hostname canonicalization""" - async with self.connect('testhost', known_hosts=(['skey.pub'], [], []), + async with self.connect('testhost', known_hosts=None, canonicalize_hostname=True, canonical_domains=['test']) as conn: self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + @asynctest + async def test_canonicalize_ip_address(self): + """Test hostname canonicalization with IP address""" + + async with self.connect('127.0.0.1', known_hosts=None, + canonicalize_hostname=True, + canonicalize_max_dots=3, + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), '127.0.0.1') + @asynctest async def test_canonicalize_proxy(self): """Test hostname canonicalization with proxy""" with open('config', 'w') as f: - f.write('UserKnownHostsFile none\n' - 'Match host localhost\nPubkeyAuthentication no') + f.write('UserKnownHostsFile none\n') async with self.connect('testhost', config='config', tunnel=f'localhost:{self._server_port}', @@ -2712,8 +2721,7 @@ async def test_canonicalize_always(self): """Test hostname canonicalization for all connections""" with open('config', 'w') as f: - f.write('UserKnownHostsFile none\n' - 'Match host localhost\nPubkeyAuthentication no') + f.write('UserKnownHostsFile none\n') async with self.connect('testhost', config='config', tunnel=f'localhost:{self._server_port}', diff --git a/tests/util.py b/tests/util.py index bb7caf3..692212a 100644 --- a/tests/util.py +++ b/tests/util.py @@ -110,9 +110,6 @@ def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): # pylint: disable=unused-argument - if host.endswith('.'): - host = host[:-1] - try: return [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, hosts[host], ('127.0.0.1', port))]