diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index fcbd126d..639abe12 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -270,7 +270,7 @@ class OAuthenticator(Authenticator): - Override the constant `user_auth_state_key` - Override various config's default values, such as `authorize_url`, `token_url`, `userdata_url`, and `login_service`. - - Override various methods called by the `authenticate` method, which + - Override various methods called by :meth:`authenticate`, which subclasses should not override. - Override handler classes such as `login_handler`, `callback_handler`, and `logout_handler`. @@ -919,7 +919,8 @@ def get_handlers(self, app): def build_userdata_request_headers(self, access_token, token_type): """ Builds and returns the headers to be used in the userdata request. - Called by the :meth:`oauthenticator.OAuthenticator.token_to_user` + + Called by :meth:`.token_to_user`. """ # token_type is case-insensitive, but the headers are case-sensitive @@ -937,7 +938,8 @@ def build_userdata_request_headers(self, access_token, token_type): def build_token_info_request_headers(self): """ Builds and returns the headers to be used in the access token request. - Called by the :meth:`oauthenticator.OAuthenticator.get_token_info`. + + Called by :meth:`.get_token_info`. The Content-Type header is specified by the OAuth 2.0 RFC in https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3. utf-8 is also @@ -971,7 +973,7 @@ def user_info_to_username(self, user_info): Returns: user_info["self.username_claim"] or raises an error if such value isn't found. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. 
""" if callable(self.username_claim): @@ -987,25 +989,12 @@ def user_info_to_username(self, user_info): return username - # Originally a GoogleOAuthenticator only feature - async def get_prev_refresh_token(self, handler, username): - """ - Retrieves the `refresh_token` from previous encrypted auth state. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` - """ - user = handler.find_user(username) - if not user: - return None - auth_state = await user.get_auth_state() - if not auth_state: - return None - return auth_state.get("refresh_token", None) - def build_access_tokens_request_params(self, handler, data=None): """ Builds the parameters that should be passed to the URL request that exchanges the OAuth code for the Access Token. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate`. + + Called by :meth:`.authenticate`. """ code = handler.get_argument("code") if not code: @@ -1042,14 +1031,36 @@ def build_access_tokens_request_params(self, handler, data=None): return params + def build_refresh_token_request_params(self, refresh_token): + """ + Builds the parameters that should be passed to the URL request + to renew the Access Token based on the Refresh Token + + Called by :meth:`.refresh_user`. + """ + params = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + # the client_id and client_secret should not be included in the access token request params + # when basic authentication is used + # ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 + if not self.basic_auth: + params["client_id"] = self.client_id + params["client_secret"] = self.client_secret + + return params + async def get_token_info(self, handler, params): """ Makes a "POST" request to `self.token_url`, with the parameters received as argument. Returns: - the JSON response to the `token_url` the request. 
+ the JSON response to the `token_url` request, as described in + https://www.rfc-editor.org/rfc/rfc6749#section-5.1 - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. """ token_info = await self.httpfetch( @@ -1073,9 +1084,9 @@ async def get_token_info(self, handler, params): async def token_to_user(self, token_info): """ Determines who the logged-in user by sending a "GET" request to - :data:`oauthenticator.OAuthenticator.userdata_url` using the `access_token`. + :attr:`.userdata_url` using the `access_token`. - If :data:`oauthenticator.OAuthenticator.userdata_from_id_token` is set then + If :attr:`.userdata_from_id_token` is set then extracts the corresponding info from an `id_token` instead. Args: @@ -1084,7 +1095,7 @@ async def token_to_user(self, token_info): Returns: the JSON response to the `userdata_url` request. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. """ if self.userdata_from_id_token: # Use id token instead of exchanging access token with userinfo endpoint. @@ -1134,7 +1145,7 @@ async def token_to_user(self, token_info): def build_auth_state_dict(self, token_info, user_info): """ - Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call. + Builds the `auth_state` dict that will be returned by a successful `authenticate` method call. May be async (requires oauthenticator >= 17.0). Args: @@ -1150,13 +1161,13 @@ def build_auth_state_dict(self, token_info, user_info): - "token_response": the full token_info response - self.user_auth_state_key: the full user_info response - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. .. versionchanged:: 17.0 This method may be async. 
""" - # We know for sure the `access_token` key exists, oterwise we would have errored out already + # We know for sure the `access_token` key exists, otherwise we would have errored out already access_token = token_info["access_token"] refresh_token = token_info.get("refresh_token", None) @@ -1221,9 +1232,9 @@ async def update_auth_model(self, auth_model): - `admin`: the admin status (True/False/None), where None means it should be unchanged. - `auth_state`: the auth state dictionary, - returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` + returned by :meth:`.build_auth_state_dict` - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. """ # NOTE: this base implementation should _not_ be updated to do anything # subclasses should have full control without calling super() @@ -1276,24 +1287,112 @@ async def authenticate(self, handler, data=None, **kwargs): """ # build the parameters to be used in the request exchanging the oauth code for the access token access_token_params = self.build_access_tokens_request_params(handler, data) - # exchange the oauth code for an access token and get the JSON with info about it token_info = await self.get_token_info(handler, access_token_params) + # call the oauth endpoints + return await self._token_to_auth_model(token_info) + + async def refresh_user(self, user, handler=None, **kwargs): + """ + Refresh user authentication + + If auth_state is enabled, constructs a fresh user model + (the same as `authenticate`) + using the access_token in auth_state. + If requests with the access token fail + (e.g. because the token has expired) + and a refresh token is found, attempts to exchange + the refresh token for a new access token to store in auth_state. + If the access token still fails after refresh, + return False to require the user to login via oauth again. + + Set `Authenticator.auth_refresh_age = 0` to disable. 
+ + Returns + ------- + + True: + If auth info is up-to-date and needs no changes + (always if `enable_auth_state` is False) + False: + If the user needs to login again + (e.g. tokens in `auth_state` unavailable or expired) + auth_model: dict + The same dict as `authenticate`, updating any fields that should change. + Can include things like group membership, + but in OAuthenticator this mainly refreshes + the token fields in `auth_state`. + """ + if not self.enable_auth_state: + # auth state not enabled, can't refresh + return True + auth_state = await user.get_auth_state() + if not auth_state: + self.log.info( + f"No auth_state found for user {user.name} refresh, need full authentication", + ) + return False + + token_info = auth_state.get("token_response") + auth_model = None + try: + auth_model = await self._token_to_auth_model(token_info) + except HTTPClientError as e: + # assume any client error means an expired token + # most likely 401 or 403 for well-behaved providers + if 400 <= e.code < 500: + self.log.info( + f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible." + ) + else: + raise + refresh_token = auth_state.get("refresh_token", None) + if refresh_token and not auth_model: + self.log.info(f"Refreshing oauth access token for {user.name}") + # access_token expired, try refreshing with refresh_token + refresh_token_params = self.build_refresh_token_request_params( + refresh_token + ) + try: + token_info = await self.get_token_info(handler, refresh_token_params) + except Exception as e: + self.log.info( + f"Error using refresh_token for {user.name}: {e}. Requiring fresh login." 
+ ) + return False + else: + self.log.debug( + f"Received fresh access_token for {user.name} via refresh_token" + ) + # refresh_token may not be returned when refreshing a token + # in which case, keep the current one + if not token_info.get("refresh_token"): + token_info["refresh_token"] = refresh_token + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # this means we were issued a fresh access token, + # but it didn't work! Fail harder? + self.log.error( + f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login." + ) + return False + + # return False if auth_model is None for "needs new login" + return auth_model or False + + async def _token_to_auth_model(self, token_info): + """ + Turn a token into the user's `auth_model` to be returned by :meth:`.authenticate`. + + Common logic shared by :meth:`.authenticate` and :meth:`.refresh_user`. + """ + # use the access_token to get userdata info user_info = await self.token_to_user(token_info) # extract the username out of the user_info dict and normalize it username = self.user_info_to_username(user_info) username = self.normalize_username(username) - # check if there any refresh_token in the token_info dict - refresh_token = token_info.get("refresh_token", None) - if self.enable_auth_state and not refresh_token: - self.log.debug( - "Refresh token was empty, will try to pull refresh_token from previous auth_state" - ) - refresh_token = await self.get_prev_refresh_token(handler, username) - if refresh_token: - token_info["refresh_token"] = refresh_token - auth_state = self.build_auth_state_dict(token_info, user_info) if isawaitable(auth_state): auth_state = await auth_state diff --git a/oauthenticator/tests/mocks.py b/oauthenticator/tests/mocks.py index d9174604..5186f5ef 100644 --- a/oauthenticator/tests/mocks.py +++ b/oauthenticator/tests/mocks.py @@ -107,6 +107,7 @@ def setup_oauth_mock( user_path=None, token_type='Bearer', 
token_request_style='post', + enable_refresh_tokens=False, scope="", ): """setup the mock client for OAuth @@ -134,6 +135,8 @@ def setup_oauth_mock( client.oauth_codes = oauth_codes = {} client.access_tokens = access_tokens = {} + client.refresh_tokens = refresh_tokens = {} + client.enable_refresh_tokens = enable_refresh_tokens def access_token(request): """Handler for access token endpoint @@ -146,26 +149,53 @@ def access_token(request): if not query: query = request.body.decode('utf8') query = parse_qs(query) - if 'code' not in query: + grant_type = query.get("grant_type", [""])[0] + if grant_type == 'authorization_code': + if 'code' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No code in access token request: url={request.url}, body={request.body}", + ) + code = query['code'][0] + if code not in oauth_codes: + return HTTPResponse( + request=request, code=403, reason=f"No such code: {code}" + ) + user = oauth_codes.pop(code) + elif grant_type == 'refresh_token': + if 'refresh_token' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No refresh_token in access token request: url={request.url}, body={request.body}", + ) + refresh_token = query['refresh_token'][0] + if refresh_token not in refresh_tokens: + return HTTPResponse( + request=request, + code=403, + reason=f"No such refresh_token: {refresh_token}", + ) + user = refresh_tokens[refresh_token] + else: return HTTPResponse( request=request, code=400, - reason=f"No code in access token request: url={request.url}, body={request.body}", - ) - code = query['code'][0] - if code not in oauth_codes: - return HTTPResponse( - request=request, code=403, reason=f"No such code: {code}" + reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}", ) # consume code, allocate token - token = uuid.uuid4().hex - user = oauth_codes.pop(code) - access_tokens[token] = user + access_token = uuid.uuid4().hex + access_tokens[access_token] = user 
model = { - 'access_token': token, + 'access_token': access_token, 'token_type': token_type, } + if client.enable_refresh_tokens: + refresh_token = uuid.uuid4().hex + refresh_tokens[refresh_token] = user + model['refresh_token'] = refresh_token if scope: model['scope'] = scope if 'id_token' in user: diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index e3225039..e5c99602 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -15,14 +15,15 @@ def user_model(username, **kwargs): """Return a user model""" - return { + model = { "username": username, "aud": client_id, "sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431", "scope": "basic", "groups": ["group1"], - **kwargs, } + model.update(kwargs) + return model @fixture(params=["id_token", "userdata_url"]) @@ -505,6 +506,90 @@ async def test_check_allowed_no_auth_state(get_authenticator, name, allowed): assert await authenticator.check_allowed(name, None) +class MockUser: + """Mock subset of JupyterHub User API from the `auth_model` dict""" + + name: str + + def __init__(self, auth_model): + self._auth_model = auth_model + self.name = auth_model["name"] + + async def get_auth_state(self): + return self._auth_model["auth_state"] + + +@mark.parametrize("enable_refresh_tokens", [True, False]) +async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens): + generic_client.enable_refresh_tokens = enable_refresh_tokens + authenticator = get_authenticator(allowed_users={"user1"}) + authenticator.manage_groups = True + authenticator.auth_state_groups_key = "oauth_user.groups" + oauth_userinfo = user_model("user1", groups=["round1"]) + handler = generic_client.handler_for_user(oauth_userinfo) + auth_model = await authenticator.get_authenticated_user(handler, None) + auth_state = auth_model["auth_state"] + assert auth_model["groups"] == ["round1"] + if enable_refresh_tokens: + assert "refresh_token" in auth_state 
+ assert "refresh_token" in auth_state["token_response"] + assert ( + auth_state["refresh_token"] == auth_state["token_response"]["refresh_token"] + ) + else: + assert "refresh_token" not in auth_state["token_response"] + assert auth_state.get("refresh_token") is None + user = MockUser(auth_model) + # case: auth_state not enabled, nothing to refresh + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is True + + # from here on, enable auth state required for refresh to do anything + authenticator.enable_auth_state = True + + # case: no auth state, but auth state enabled needs refresh + auth_without_state = auth_model.copy() + auth_without_state["auth_state"] = None + user_without_state = MockUser(auth_without_state) + refreshed = await authenticator.refresh_user(user_without_state, handler) + assert refreshed is False + + # case: actually refresh + oauth_userinfo["groups"] = ["refreshed"] + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed + assert refreshed["name"] == auth_model["name"] + assert refreshed["groups"] == ["refreshed"] + refreshed_state = refreshed["auth_state"] + assert "access_token" in refreshed_state + # refresh with access token succeeds, keeps tokens unchanged + assert refreshed_state.get("refresh_token") == auth_state.get("refresh_token") + assert refreshed_state["access_token"] == auth_state["access_token"] + + # case: access token is no longer valid, triggers refresh + oauth_userinfo["groups"] = ["token_refreshed"] + generic_client.access_tokens.pop(refreshed_state["access_token"]) + refreshed = await authenticator.refresh_user(user, handler) + if enable_refresh_tokens: + # access_token refreshed + assert refreshed + refreshed_state = refreshed["auth_state"] + assert ( + refreshed_state["access_token"] != auth_model["auth_state"]["access_token"] + ) + assert refreshed["groups"] == ["token_refreshed"] + else: + assert refreshed is False + + if enable_refresh_tokens: + # case: token used 
for refresh is no longer valid + user = MockUser(refreshed) + generic_client.access_tokens.pop(refreshed_state["access_token"]) + generic_client.refresh_tokens.pop(refreshed_state["refresh_token"]) + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is False + + @mark.parametrize( "test_variation_id,class_config,expect_config,expect_loglevel,expect_message", [ diff --git a/oauthenticator/tests/test_github.py b/oauthenticator/tests/test_github.py index e49fe064..1ca5f590 100644 --- a/oauthenticator/tests/test_github.py +++ b/oauthenticator/tests/test_github.py @@ -141,7 +141,7 @@ async def test_github( assert user_info == handled_user_model assert auth_model["name"] == user_info[authenticator.username_claim] else: - assert auth_model == None + assert auth_model is None def make_link_header(urlinfo, page):