Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OAuthenticator.refresh_user #579

Merged
merged 16 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 96 additions & 14 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,12 +1054,34 @@ def build_access_tokens_request_params(self, handler, data=None):

return params

def build_refresh_token_request_params(self, refresh_token):
"""
Builds the parameters that should be passed to the URL request
to renew the Access Token based on the Refresh Token

Called by the :meth:`oauthenticator.OAuthenticator.refresh_user`.
"""
params = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
minrk marked this conversation as resolved.
Show resolved Hide resolved
}

# the client_id and client_secret should not be included in the access token request params
# when basic authentication is used
# ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
if not self.basic_auth:
params["client_id"] = self.client_id
params["client_secret"] = self.client_secret

return params

async def get_token_info(self, handler, params):
"""
Makes a "POST" request to `self.token_url`, with the parameters received as argument.

Returns:
the JSON response to the `token_url` the request.
the JSON response to the `token_url` the request as described in
https://www.rfc-editor.org/rfc/rfc6749#section-5.1

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
Expand Down Expand Up @@ -1146,7 +1168,7 @@ async def token_to_user(self, token_info):

def build_auth_state_dict(self, token_info, user_info):
"""
Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call.
Builds the `auth_state` dict that will be returned by a successful `authenticate` method call.
May be async (requires oauthenticator >= 17.0).

Args:
Expand All @@ -1168,7 +1190,7 @@ def build_auth_state_dict(self, token_info, user_info):
This method may be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
# We know for sure the `access_token` key exists, otherwise we would have errored out already
access_token = token_info["access_token"]

refresh_token = token_info.get("refresh_token", None)
Expand Down Expand Up @@ -1288,24 +1310,84 @@ async def authenticate(self, handler, data=None, **kwargs):
"""
# build the parameters to be used in the request exchanging the oauth code for the access token
access_token_params = self.build_access_tokens_request_params(handler, data)
# exchange the oauth code for an access token and get the JSON with info about it
token_info = await self.get_token_info(handler, access_token_params)
# call the oauth endpoints
return await self._token_to_auth_model(token_info)

async def refresh_user(self, user, handler=None, **kwargs):
"""
Renew the Access Token with a valid Refresh Token
"""
if not self.enable_auth_state:
# auth state not enabled, can't refresh
self.log.debug("auth_state disabled, no auth state to refresh")
return True
minrk marked this conversation as resolved.
Show resolved Hide resolved
auth_state = await user.get_auth_state()
if not auth_state:
self.log.info(
f"No auth_state found for user {user.name} refresh, need full authentication",
)
return False

token_info = auth_state.get("token_response")
auth_model = None
try:
auth_model = await self._token_to_auth_model(token_info)
except HTTPClientError as e:
# assume any client error means an expired token
# most likely 401 or 403 for well-behaved providers
if 400 <= e.code < 500:
self.log.info(
f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible."
)
else:
raise
refresh_token = auth_state.get("refresh_token", None)
if refresh_token and not auth_model:
self.log.info(f"Refreshing oauth access token for {user.name}")
# access_token expired, try refreshing with refresh_token
refresh_token_params = self.build_refresh_token_request_params(
refresh_token
)
try:
token_info = await self.get_token_info(handler, refresh_token_params)
except Exception as e:
self.log.info(
f"Error using refresh_token for {user.name}: {e}. Requiring fresh login."
)
return False
else:
self.log.debug(
f"Received fresh access_token for {user.name} via refresh_token"
)
# refresh_token may not be returned when refreshing a token
# in which case, keep the current one
if not token_info.get("refresh_token"):
token_info["refresh_token"] = refresh_token
try:
auth_model = await self._token_to_auth_model(token_info)
except Exception as e:
# this means we were issued a fresh access token,
# but it didn't work! Fail harder?
self.log.error(
f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login."
)
return False

# return False if auth_model is None for "needs new login"
return auth_model or False

async def _token_to_auth_model(self, token_info):
"""
Common logic shared by authenticate() and refresh_user()
"""

# use the access_token to get userdata info
user_info = await self.token_to_user(token_info)
# extract the username out of the user_info dict and normalize it
username = self.user_info_to_username(user_info)
username = self.normalize_username(username)

# check if there any refresh_token in the token_info dict
refresh_token = token_info.get("refresh_token", None)
if self.enable_auth_state and not refresh_token:
self.log.debug(
"Refresh token was empty, will try to pull refresh_token from previous auth_state"
)
refresh_token = await self.get_prev_refresh_token(handler, username)
minrk marked this conversation as resolved.
Show resolved Hide resolved
if refresh_token:
token_info["refresh_token"] = refresh_token

auth_state = self.build_auth_state_dict(token_info, user_info)
if isawaitable(auth_state):
auth_state = await auth_state
Expand Down
52 changes: 41 additions & 11 deletions oauthenticator/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def setup_oauth_mock(
user_path=None,
token_type='Bearer',
token_request_style='post',
enable_refresh_tokens=False,
scope="",
):
"""setup the mock client for OAuth
Expand Down Expand Up @@ -134,6 +135,8 @@ def setup_oauth_mock(

client.oauth_codes = oauth_codes = {}
client.access_tokens = access_tokens = {}
client.refresh_tokens = refresh_tokens = {}
client.enable_refresh_tokens = enable_refresh_tokens

def access_token(request):
"""Handler for access token endpoint
Expand All @@ -146,26 +149,53 @@ def access_token(request):
if not query:
query = request.body.decode('utf8')
query = parse_qs(query)
if 'code' not in query:
grant_type = query.get("grant_type", [""])[0]
if grant_type == 'authorization_code':
if 'code' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
)
user = oauth_codes.pop(code)
elif grant_type == 'refresh_token':
if 'refresh_token' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No refresh_token in access token request: url={request.url}, body={request.body}",
)
refresh_token = query['refresh_token'][0]
if refresh_token not in refresh_token:
return HTTPResponse(
request=request,
code=403,
reason=f"No such refresh_toekn: {refresh_token}",
)
user = refresh_tokens[refresh_token]
else:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}",
)

# consume code, allocate token
token = uuid.uuid4().hex
user = oauth_codes.pop(code)
access_tokens[token] = user
access_token = uuid.uuid4().hex
access_tokens[access_token] = user
model = {
'access_token': token,
'access_token': access_token,
'token_type': token_type,
}
if client.enable_refresh_tokens:
refresh_token = uuid.uuid4().hex
refresh_tokens[refresh_token] = user
model['refresh_token'] = refresh_token
if scope:
model['scope'] = scope
if 'id_token' in user:
Expand Down
89 changes: 87 additions & 2 deletions oauthenticator/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@

def user_model(username, **kwargs):
"""Return a user model"""
return {
model = {
"username": username,
"aud": client_id,
"sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431",
"scope": "basic",
"groups": ["group1"],
**kwargs,
}
model.update(kwargs)
return model


@fixture(params=["id_token", "userdata_url"])
Expand Down Expand Up @@ -505,6 +506,90 @@ async def test_check_allowed_no_auth_state(get_authenticator, name, allowed):
assert await authenticator.check_allowed(name, None)


class MockUser:
"""Mock subset of JupyterHub User API from the `auth_model` dict"""

name: str

def __init__(self, auth_model):
self._auth_model = auth_model
self.name = auth_model["name"]

async def get_auth_state(self):
return self._auth_model["auth_state"]


@mark.parametrize("enable_refresh_tokens", [True, False])
async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens):
generic_client.enable_refresh_tokens = enable_refresh_tokens
authenticator = get_authenticator(allowed_users={"user1"})
authenticator.manage_groups = True
authenticator.auth_state_groups_key = "oauth_user.groups"
oauth_userinfo = user_model("user1", groups=["round1"])
handler = generic_client.handler_for_user(oauth_userinfo)
auth_model = await authenticator.get_authenticated_user(handler, None)
auth_state = auth_model["auth_state"]
assert auth_model["groups"] == ["round1"]
if enable_refresh_tokens:
assert "refresh_token" in auth_state
assert "refresh_token" in auth_state["token_response"]
assert (
auth_state["refresh_token"] == auth_state["token_response"]["refresh_token"]
)
else:
assert "refresh_token" not in auth_state["token_response"]
assert auth_state.get("refresh_token") is None
user = MockUser(auth_model)
# case: auth_state not enabled, nothing to refresh
refreshed = await authenticator.refresh_user(user, handler)
assert refreshed is True

# from here on, enable auth state required for refresh to do anything
authenticator.enable_auth_state = True

# case: no auth state, but auth state enabled needs refresh
auth_without_state = auth_model.copy()
auth_without_state["auth_state"] = None
user_without_state = MockUser(auth_without_state)
refreshed = await authenticator.refresh_user(user_without_state, handler)
assert refreshed is False

# case: actually refresh
oauth_userinfo["groups"] = ["refreshed"]
refreshed = await authenticator.refresh_user(user, handler)
assert refreshed
assert refreshed["name"] == auth_model["name"]
assert refreshed["groups"] == ["refreshed"]
refreshed_state = refreshed["auth_state"]
assert "access_token" in refreshed_state
# refresh with access token succeeds, keeps tokens unchanged
assert refreshed_state.get("refresh_token") == auth_state.get("refresh_token")
assert refreshed_state["access_token"] == auth_state["access_token"]

# case: access token is no longer valid, triggers refresh
oauth_userinfo["groups"] = ["token_refreshed"]
generic_client.access_tokens.pop(refreshed_state["access_token"])
refreshed = await authenticator.refresh_user(user, handler)
if enable_refresh_tokens:
# access_token refreshed
assert refreshed
refreshed_state = refreshed["auth_state"]
assert (
refreshed_state["access_token"] != auth_model["auth_state"]["access_token"]
)
assert refreshed["groups"] == ["token_refreshed"]
else:
assert refreshed is False

if enable_refresh_tokens:
# case: token used for refresh is no longer valid
user = MockUser(refreshed)
generic_client.access_tokens.pop(refreshed_state["access_token"])
generic_client.refresh_tokens.pop(refreshed_state["refresh_token"])
refreshed = await authenticator.refresh_user(user, handler)
assert refreshed is False


@mark.parametrize(
"test_variation_id,class_config,expect_config,expect_loglevel,expect_message",
[
Expand Down
2 changes: 1 addition & 1 deletion oauthenticator/tests/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def test_github(
assert user_info == handled_user_model
assert auth_model["name"] == user_info[authenticator.username_claim]
else:
assert auth_model == None
assert auth_model is None


def make_link_header(urlinfo, page):
Expand Down