diff --git a/changelog.d/17972.misc b/changelog.d/17972.misc new file mode 100644 index 00000000000..e7f009d20d4 --- /dev/null +++ b/changelog.d/17972.misc @@ -0,0 +1 @@ +Consolidate SSO redirects through `/_matrix/client/v3/login/sso/redirect(/{idpId})`. diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 03a3e96f289..655b5edd7a2 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -23,7 +23,8 @@ import hmac from hashlib import sha256 -from urllib.parse import urlencode +from typing import Optional +from urllib.parse import urlencode, urljoin from synapse.config import ConfigError from synapse.config.homeserver import HomeServerConfig @@ -66,3 +67,42 @@ def build_user_consent_uri(self, user_id: str) -> str: urlencode({"u": user_id, "h": mac}), ) return consent_uri + + +class LoginSSORedirectURIBuilder: + def __init__(self, hs_config: HomeServerConfig): + self._public_baseurl = hs_config.server.public_baseurl + + def build_login_sso_redirect_uri( + self, *, idp_id: Optional[str], client_redirect_url: str + ) -> str: + """Build a `/login/sso/redirect` URI for the given identity provider. + + Builds `/_matrix/client/v3/login/sso/redirect/{idpId}?redirectUrl=xxx` when `idp_id` is specified. + Otherwise, builds `/_matrix/client/v3/login/sso/redirect?redirectUrl=xxx` when `idp_id` is `None`. + + Args: + idp_id: Optional ID of the identity provider + client_redirect_url: URL to redirect the user to after login + + Returns + The URI to follow when choosing a specific identity provider. + """ + base_url = urljoin( + self._public_baseurl, + f"{CLIENT_API_PREFIX}/v3/login/sso/redirect", + ) + + serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url}) + + if idp_id: + resultant_url = urljoin( + # We have to add a trailing slash to the base URL to ensure that the + # last path segment is not stripped away when joining with another path. + f"{base_url}/", + f"{idp_id}?{serialized_query_parameters}", + ) + else: + resultant_url = f"{base_url}?{serialized_query_parameters}" + + return resultant_url diff --git a/synapse/config/cas.py b/synapse/config/cas.py index fa59c350c15..c32bf36951d 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -20,7 +20,7 @@ # # -from typing import Any, List +from typing import Any, List, Optional from synapse.config.sso import SsoAttributeRequirement from synapse.types import JsonDict @@ -46,7 +46,9 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: # TODO Update this to a _synapse URL. public_baseurl = self.root.server.public_baseurl - self.cas_service_url = public_baseurl + "_matrix/client/r0/login/cas/ticket" + self.cas_service_url: Optional[str] = ( + public_baseurl + "_matrix/client/r0/login/cas/ticket" + ) self.cas_protocol_version = cas_config.get("protocol_version") if ( diff --git a/synapse/config/server.py b/synapse/config/server.py index ad7331de428..6b299836176 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -332,8 +332,14 @@ def read_config(self, config: JsonDict, **kwargs: Any) -> None: logger.info("Using default public_baseurl %s", public_baseurl) else: self.serve_client_wellknown = True + # Ensure that public_baseurl ends with a trailing slash if public_baseurl[-1] != "/": public_baseurl += "/" + + # Scrutinize user-provided config + if not isinstance(public_baseurl, str): + raise ConfigError("Must be a string", ("public_baseurl",)) + self.public_baseurl = public_baseurl # check that public_baseurl is valid diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py index f26929bd608..5e599f85b0b 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py @@ -21,6 +21,7 @@ import logging from typing import TYPE_CHECKING +from synapse.api.urls import LoginSSORedirectURIBuilder from synapse.http.server import ( DirectServeHtmlResource, finish_request, @@ -49,6 +50,8 @@ def __init__(self, hs: "HomeServer"): hs.config.sso.sso_login_idp_picker_template ) self._server_name = hs.hostname + self._public_baseurl = hs.config.server.public_baseurl + self._login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config) async def _async_render_GET(self, request: SynapseRequest) -> None: client_redirect_url = parse_string( @@ -56,25 +59,23 @@ async def _async_render_GET(self, request: SynapseRequest) -> None: ) idp = parse_string(request, "idp", required=False) - # if we need to pick an IdP, do so + # If we need to pick an IdP, do so if not idp: return await self._serve_id_picker(request, client_redirect_url) - # otherwise, redirect to the IdP's redirect URI - providers = self._sso_handler.get_identity_providers() - auth_provider = providers.get(idp) - if not auth_provider: - logger.info("Unknown idp %r", idp) - self._sso_handler.render_error( - request, "unknown_idp", "Unknown identity provider ID" + # Otherwise, redirect to the login SSO redirect endpoint for the given IdP + # (which will in turn take us to the the IdP's redirect URI). + # + # We could go directly to the IdP's redirect URI, but this way we ensure that + # the user goes through the same logic as normal flow. Additionally, if a proxy + # needs to intercept the request, it only needs to intercept the one endpoint. + sso_login_redirect_url = ( + self._login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id=idp, client_redirect_url=client_redirect_url ) - return - - sso_url = await auth_provider.handle_redirect_request( - request, client_redirect_url.encode("utf8") ) - logger.info("Redirecting to %s", sso_url) - request.redirect(sso_url) + logger.info("Redirecting to %s", sso_login_redirect_url) + request.redirect(sso_login_redirect_url) finish_request(request) async def _serve_id_picker( diff --git a/tests/api/test_urls.py b/tests/api/test_urls.py new file mode 100644 index 00000000000..ce156a05dc4 --- /dev/null +++ b/tests/api/test_urls.py @@ -0,0 +1,55 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# + + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.urls import LoginSSORedirectURIBuilder +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TRICKY_TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' + + +class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config) + + def test_no_idp_id(self) -> None: + self.assertEqual( + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id=None, client_redirect_url="http://example.com/redirect" + ), + "https://test/_matrix/client/v3/login/sso/redirect?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect", + ) + + def test_explicit_idp_id(self) -> None: + self.assertEqual( + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="oidc-github", client_redirect_url="http://example.com/redirect" + ), + "https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect", + ) + + def test_tricky_redirect_uri(self) -> None: + self.assertEqual( + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="oidc-github", + client_redirect_url=TRICKY_TEST_CLIENT_REDIRECT_URL, + ), + "https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22", + ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index cbd6d8d4bf8..1451fd7c29c 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -43,6 +43,7 @@ import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.errors import Codes +from synapse.api.urls import LoginSSORedirectURIBuilder from synapse.appservice import ApplicationService from synapse.http.client import RawHeaders from synapse.module_api import ModuleApi @@ -69,6 +70,10 @@ except ImportError: HAS_JWT = False +import logging + +logger = logging.getLogger(__name__) + # synapse server name: used to populate public_baseurl in some tests SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" @@ -77,7 +82,7 @@ # FakeChannel.isSecure() returns False, so synapse will see the requested uri as # http://..., so using http in the public_baseurl stops Synapse trying to redirect to # https://.... -BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) +PUBLIC_BASEURL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) # CAS server used in some tests CAS_SERVER = "https://fake.test" @@ -109,6 +114,23 @@ ] +def get_relative_uri_from_absolute_uri(absolute_uri: str) -> str: + """ + Peels off the path and query string from an absolute URI. Useful when interacting + with `make_request(...)` util function which expects a relative path instead of a + full URI. + """ + parsed_uri = urllib.parse.urlparse(absolute_uri) + # Sanity check that we're working with an absolute URI + assert parsed_uri.scheme == "http" or parsed_uri.scheme == "https" + + relative_uri = parsed_uri.path + if parsed_uri.query: + relative_uri += "?" + parsed_uri.query + + return relative_uri + + class TestSpamChecker: def __init__(self, config: None, api: ModuleApi): api.register_spam_checker_callbacks( @@ -614,7 +636,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): def default_config(self) -> Dict[str, Any]: config = super().default_config() - config["public_baseurl"] = BASE_URL + config["public_baseurl"] = PUBLIC_BASEURL config["cas_config"] = { "enabled": True, @@ -653,6 +675,9 @@ def default_config(self) -> Dict[str, Any]: ] return config + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.login_sso_redirect_url_builder = LoginSSORedirectURIBuilder(hs.config) + def create_resource_dict(self) -> Dict[str, Resource]: d = super().create_resource_dict() d.update(build_synapse_client_resource_tree(self.hs)) @@ -725,6 +750,32 @@ def test_multi_sso_redirect_to_cas(self) -> None: + "&idp=cas", shorthand=False, ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="cas", client_redirect_url=TEST_CLIENT_REDIRECT_URL + ), + ) + + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], + ) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -750,6 +801,32 @@ def test_multi_sso_redirect_to_saml(self) -> None: + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="saml", client_redirect_url=TEST_CLIENT_REDIRECT_URL + ), + ) + + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], + ) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -773,10 +850,35 @@ def test_login_via_oidc(self) -> None: # pick the default OIDC provider channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" - + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) - + "&idp=oidc", + f"/_synapse/client/pick_idp?redirectUrl={urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)}&idp=oidc", + ) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="oidc", client_redirect_url=TEST_CLIENT_REDIRECT_URL + ), + ) + + with fake_oidc_server.patch_homeserver(hs=self.hs): + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], ) + self.assertEqual(channel.code, 302, channel.result) location_headers = channel.headers.getRawHeaders("Location") assert location_headers @@ -838,12 +940,38 @@ def test_login_via_oidc(self) -> None: self.assertEqual(chan.json_body["user_id"], "@user1:test") def test_multi_sso_redirect_to_unknown(self) -> None: - """An unknown IdP should cause a 400""" + """An unknown IdP should cause a 404""" channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", ) - self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.code, 302, channel.result) + location_headers = channel.headers.getRawHeaders("Location") + assert location_headers + sso_login_redirect_uri = location_headers[0] + + # it should redirect us to the standard login SSO redirect flow + self.assertEqual( + sso_login_redirect_uri, + self.login_sso_redirect_url_builder.build_login_sso_redirect_uri( + idp_id="xyz", client_redirect_url="http://x" + ), + ) + + # follow the redirect + channel = self.make_request( + "GET", + # We have to make this relative to be compatible with `make_request(...)` + get_relative_uri_from_absolute_uri(sso_login_redirect_uri), + # We have to set the Host header to match the `public_baseurl` to avoid + # the extra redirect in the `SsoRedirectServlet` in order for the + # cookies to be visible. + custom_headers=[ + ("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME), + ], + ) + + self.assertEqual(channel.code, 404, channel.result) def test_client_idp_redirect_to_unknown(self) -> None: """If the client tries to pick an unknown IdP, return a 404""" @@ -1473,7 +1601,7 @@ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def default_config(self) -> Dict[str, Any]: config = super().default_config() - config["public_baseurl"] = BASE_URL + config["public_baseurl"] = PUBLIC_BASEURL config["oidc_config"] = {} config["oidc_config"].update(TEST_OIDC_CONFIG) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index a1c284726ad..dbd6049f9fc 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -889,7 +889,7 @@ def initiate_sso_login( "GET", uri, ) - assert channel.code == 302 + assert channel.code == 302, f"Expected 302 for {uri}, got {channel.code}" # hit the redirect url again with the right Host header, which should now issue # a cookie and redirect to the SSO provider. @@ -901,17 +901,18 @@ def get_location(channel: FakeChannel) -> str: location = get_location(channel) parts = urllib.parse.urlsplit(location) + next_uri = urllib.parse.urlunsplit(("", "") + parts[2:]) channel = make_request( self.reactor, self.site, "GET", - urllib.parse.urlunsplit(("", "") + parts[2:]), + next_uri, custom_headers=[ ("Host", parts[1]), ], ) - assert channel.code == 302 + assert channel.code == 302, f"Expected 302 for {next_uri}, got {channel.code}" channel.extract_cookies(cookies) return get_location(channel)