From d9c2d0334f8e9b6d534c297cea276117954ae49e Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Wed, 7 Oct 2020 12:46:43 -0300 Subject: [PATCH 01/16] Implementing the oauth2 integration between FastAPI and Authlib --- .../integrations/fastapi_oauth2/__init__.py | 2 + .../fastapi_oauth2/authorization_server.py | 203 ++++++++++++++++++ authlib/integrations/fastapi_oauth2/errors.py | 9 + .../fastapi_oauth2/resource_protector.py | 107 +++++++++ 4 files changed, 321 insertions(+) create mode 100644 authlib/integrations/fastapi_oauth2/__init__.py create mode 100644 authlib/integrations/fastapi_oauth2/authorization_server.py create mode 100644 authlib/integrations/fastapi_oauth2/errors.py create mode 100644 authlib/integrations/fastapi_oauth2/resource_protector.py diff --git a/authlib/integrations/fastapi_oauth2/__init__.py b/authlib/integrations/fastapi_oauth2/__init__.py new file mode 100644 index 00000000..6765a4f9 --- /dev/null +++ b/authlib/integrations/fastapi_oauth2/__init__.py @@ -0,0 +1,2 @@ +from .authorization_server import AuthorizationServer +from .resource_protector import ResourceProtector diff --git a/authlib/integrations/fastapi_oauth2/authorization_server.py b/authlib/integrations/fastapi_oauth2/authorization_server.py new file mode 100644 index 00000000..c8d17bb4 --- /dev/null +++ b/authlib/integrations/fastapi_oauth2/authorization_server.py @@ -0,0 +1,203 @@ +import json +from werkzeug.utils import import_string +from fastapi.responses import JSONResponse +from authlib.deprecate import deprecate +from authlib.oauth2 import ( + OAuth2Request, + HttpRequest, + AuthorizationServer as _AuthorizationServer, +) +from authlib.oauth2.rfc6750 import BearerToken +from authlib.oauth2.rfc8414 import AuthorizationServerMetadata +from authlib.common.security import generate_token +from authlib.common.encoding import to_unicode + + +class AuthorizationServer(_AuthorizationServer): + """Flask implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`. + Initialize it with ``query_client``, ``save_token`` methods and Flask + app instance:: + + def query_client(client_id): + return Client.query.filter_by(client_id=client_id).first() + + def save_token(token, request): + if request.user: + user_id = request.user.get_user_id() + else: + user_id = None + client = request.client + tok = Token( + client_id=client.client_id, + user_id=user.get_user_id(), + **token + ) + db.session.add(tok) + db.session.commit() + + server = AuthorizationServer(app, query_client, save_token) + # or initialize lazily + server = AuthorizationServer() + server.init_app(app, query_client, save_token) + """ + metadata_class = AuthorizationServerMetadata + + def __init__(self, app=None, query_client=None, save_token=None): + super(AuthorizationServer, self).__init__( + query_client=query_client, + save_token=save_token, + ) + self.config = {} + if app is not None: + self.init_app(app) + + def init_app(self, app, query_client=None, save_token=None): + """Initialize later with Flask app instance.""" + if query_client is not None: + self.query_client = query_client + if save_token is not None: + self.save_token = save_token + + self.generate_token = self.create_bearer_token_generator(app.config) + + metadata_file = app.config.get('OAUTH2_METADATA_FILE') + if metadata_file: + with open(metadata_file) as f: + metadata = self.metadata_class(json.loads(f)) + metadata.validate() + self.metadata = metadata + + self.config.setdefault('error_uris', app.config.get('OAUTH2_ERROR_URIS')) + if app.config.get('OAUTH2_JWT_ENABLED'): + deprecate('Define "get_jwt_config" in OpenID Connect grants', '1.0') + self.init_jwt_config(app.config) + + def init_jwt_config(self, config): + """Initialize JWT related configuration.""" + jwt_iss = config.get('OAUTH2_JWT_ISS') + if not jwt_iss: + raise RuntimeError('Missing "OAUTH2_JWT_ISS" configuration.') + + jwt_key_path = config.get('OAUTH2_JWT_KEY_PATH') + if jwt_key_path: + with open(jwt_key_path, 'r') as f: + if jwt_key_path.endswith('.json'): + jwt_key = json.loads(f) + else: + jwt_key = to_unicode(f.read()) + else: + jwt_key = config.get('OAUTH2_JWT_KEY') + + if not jwt_key: + raise RuntimeError('Missing "OAUTH2_JWT_KEY" configuration.') + + jwt_alg = config.get('OAUTH2_JWT_ALG') + if not jwt_alg: + raise RuntimeError('Missing "OAUTH2_JWT_ALG" configuration.') + + jwt_exp = config.get('OAUTH2_JWT_EXP', 3600) + self.config.setdefault('jwt_iss', jwt_iss) + self.config.setdefault('jwt_key', jwt_key) + self.config.setdefault('jwt_alg', jwt_alg) + self.config.setdefault('jwt_exp', jwt_exp) + + def get_error_uris(self, request): + error_uris = self.config.get('error_uris') + if error_uris: + return dict(error_uris) + + def create_oauth2_request(self, request): + return OAuth2Request(request.method, str(request.url), request.body, request.headers) + + def create_json_request(self, request): + return HttpRequest(request.method, str(request.url), request.body, request.headers) + + def handle_response(self, status, body, headers): + return JSONResponse(content=body, status_code=status, headers=dict(headers)) + + def create_token_expires_in_generator(self, config): + """Create a generator function for generating ``expires_in`` value. + Developers can re-implement this method with a subclass if other means + required. The default expires_in value is defined by ``grant_type``, + different ``grant_type`` has different value. It can be configured + with:: + + OAUTH2_TOKEN_EXPIRES_IN = { + 'authorization_code': 864000, + 'urn:ietf:params:oauth:grant-type:jwt-bearer': 3600, + } + """ + expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN') + return create_token_expires_in_generator(expires_conf) + + def create_bearer_token_generator(self, config): + """Create a generator function for generating ``token`` value. This + method will create a Bearer Token generator with + :class:`authlib.oauth2.rfc6750.BearerToken`. By default, it will not + generate ``refresh_token``, which can be turn on by configuration + ``OAUTH2_REFRESH_TOKEN_GENERATOR=True``. + """ + conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True) + access_token_generator = create_token_generator(conf, 42) + + conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False) + refresh_token_generator = create_token_generator(conf, 48) + + expires_generator = self.create_token_expires_in_generator(config) + return BearerToken( + access_token_generator, + refresh_token_generator, + expires_generator + ) + + def validate_consent_request(self, request=None, end_user=None): + """Validate current HTTP request for authorization page. This page + is designed for resource owner to grant or deny the authorization:: + + @app.route('/authorize', methods=['GET']) + def authorize(): + try: + grant = server.validate_consent_request(end_user=current_user) + return render_template( + 'authorize.html', + grant=grant, + user=current_user + ) + except OAuth2Error as error: + return render_template( + 'error.html', + error=error + ) + """ + req = self.create_oauth2_request(request) + req.user = end_user + + grant = self.get_authorization_grant(req) + grant.validate_consent_request() + if not hasattr(grant, 'prompt'): + grant.prompt = None + return grant + + +def create_token_expires_in_generator(expires_in_conf=None): + data = {} + data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) + if expires_in_conf: + data.update(expires_in_conf) + + def expires_in(client, grant_type): + return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN) + + return expires_in + + +def create_token_generator(token_generator_conf, length=42): + if callable(token_generator_conf): + return token_generator_conf + + if isinstance(token_generator_conf, str): + return import_string(token_generator_conf) + elif token_generator_conf is True: + def token_generator(*args, **kwargs): + return generate_token(length) + return token_generator diff --git a/authlib/integrations/fastapi_oauth2/errors.py b/authlib/integrations/fastapi_oauth2/errors.py new file mode 100644 index 00000000..bed50e55 --- /dev/null +++ b/authlib/integrations/fastapi_oauth2/errors.py @@ -0,0 +1,9 @@ +from fastapi import HTTPException + + +def raise_http_exception(status, body, headers): + raise HTTPException( + status_code=status, + detail=body, + headers=dict(headers) + ) diff --git a/authlib/integrations/fastapi_oauth2/resource_protector.py b/authlib/integrations/fastapi_oauth2/resource_protector.py new file mode 100644 index 00000000..8d43c311 --- /dev/null +++ b/authlib/integrations/fastapi_oauth2/resource_protector.py @@ -0,0 +1,107 @@ +import functools +from contextlib import contextmanager +from authlib.oauth2 import ( + OAuth2Error, + ResourceProtector as _ResourceProtector +) +from authlib.oauth2.rfc6749 import ( + MissingAuthorizationError, + HttpRequest, +) +from .errors import raise_http_exception + + +class ResourceProtector(_ResourceProtector): + """A protecting method for resource servers. Creating a ``require_oauth`` + decorator easily with ResourceProtector:: + + from authlib.integrations.flask_oauth2 import ResourceProtector + + require_oauth = ResourceProtector() + + # add bearer token validator + from authlib.oauth2.rfc6750 import BearerTokenValidator + from project.models import Token + + class MyBearerTokenValidator(BearerTokenValidator): + def authenticate_token(self, token_string): + return Token.query.filter_by(access_token=token_string).first() + + def request_invalid(self, request): + return False + + def token_revoked(self, token): + return False + + require_oauth.register_token_validator(MyBearerTokenValidator()) + + # protect resource with require_oauth + + @app.route('/user') + @require_oauth('profile') + def user_profile(): + user = User.query.get(current_token.user_id) + return jsonify(user.to_dict()) + + """ + def raise_error_response(self, error): + """Raise HTTPException for OAuth2Error. Developers can re-implement + this method to customize the error response. + + :param error: OAuth2Error + :raise: HTTPException + """ + status = error.status_code + body = dict(error.get_body()) + headers = error.get_headers() + raise_http_exception(status, body, headers) + + def acquire_token(self, request=None, scope=None, operator='AND'): + """A method to acquire current valid token with the given scope. + + :param scope: string or list of scope values + :param operator: value of "AND" or "OR" + :return: token object + """ + request = HttpRequest( + request.method, + request.url, + {}, + request.headers + ) + if not callable(operator): + operator = operator.upper() + token = self.validate_request(scope, request, operator) + return token + + @contextmanager + def acquire(self, request=None, scope=None, operator='AND'): + """The with statement of ``require_oauth``. Instead of using a + decorator, you can use a with statement instead:: + + @app.route('/api/user') + def user_api(): + with require_oauth.acquire('profile') as token: + user = User.query.get(token.user_id) + return jsonify(user.to_dict()) + """ + try: + yield self.acquire_token(request, scope, operator) + except OAuth2Error as error: + self.raise_error_response(error) + + def __call__(self, scope=None, operator='AND', optional=False): + def wrapper(f): + @functools.wraps(f) + def decorated(*args, **kwargs): + try: + self.acquire_token(scope, operator) + except MissingAuthorizationError as error: + if optional: + return f(*args, **kwargs) + self.raise_error_response(error) + except OAuth2Error as error: + self.raise_error_response(error) + return f(*args, **kwargs) + return decorated + return wrapper From 4e320282fa624b6656576a1bded773184859955b Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Thu, 15 Oct 2020 16:26:58 -0300 Subject: [PATCH 02/16] Remove the deprecated code and code linter --- .../integrations/fastapi_oauth2/__init__.py | 2 + .../fastapi_oauth2/authorization_server.py | 183 ++++++------------ authlib/integrations/fastapi_oauth2/errors.py | 9 - .../fastapi_oauth2/resource_protector.py | 84 +++----- 4 files changed, 82 insertions(+), 196 deletions(-) delete mode 100644 authlib/integrations/fastapi_oauth2/errors.py diff --git a/authlib/integrations/fastapi_oauth2/__init__.py b/authlib/integrations/fastapi_oauth2/__init__.py index 6765a4f9..9a001a77 100644 --- a/authlib/integrations/fastapi_oauth2/__init__.py +++ b/authlib/integrations/fastapi_oauth2/__init__.py @@ -1,2 +1,4 @@ +"""FastAPI package implementation.""" + from .authorization_server import AuthorizationServer from .resource_protector import ResourceProtector diff --git a/authlib/integrations/fastapi_oauth2/authorization_server.py b/authlib/integrations/fastapi_oauth2/authorization_server.py index c8d17bb4..5d550771 100644 --- a/authlib/integrations/fastapi_oauth2/authorization_server.py +++ b/authlib/integrations/fastapi_oauth2/authorization_server.py @@ -1,7 +1,8 @@ +"""Implementation of authlib.oauth2.rfc6749.AuthorizationServer class for FastAPI.""" + import json from werkzeug.utils import import_string from fastapi.responses import JSONResponse -from authlib.deprecate import deprecate from authlib.oauth2 import ( OAuth2Request, HttpRequest, @@ -10,101 +11,40 @@ from authlib.oauth2.rfc6750 import BearerToken from authlib.oauth2.rfc8414 import AuthorizationServerMetadata from authlib.common.security import generate_token -from authlib.common.encoding import to_unicode class AuthorizationServer(_AuthorizationServer): - """Flask implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`. - Initialize it with ``query_client``, ``save_token`` methods and Flask - app instance:: - - def query_client(client_id): - return Client.query.filter_by(client_id=client_id).first() - - def save_token(token, request): - if request.user: - user_id = request.user.get_user_id() - else: - user_id = None - client = request.client - tok = Token( - client_id=client.client_id, - user_id=user.get_user_id(), - **token - ) - db.session.add(tok) - db.session.commit() - - server = AuthorizationServer(app, query_client, save_token) - # or initialize lazily - server = AuthorizationServer() - server.init_app(app, query_client, save_token) - """ - metadata_class = AuthorizationServerMetadata + """AuthorizationServer class.""" - def __init__(self, app=None, query_client=None, save_token=None): - super(AuthorizationServer, self).__init__( - query_client=query_client, - save_token=save_token, - ) + def __init__(self, query_client=None, save_token=None): + super().__init__(query_client=query_client, save_token=save_token) self.config = {} - if app is not None: - self.init_app(app) def init_app(self, app, query_client=None, save_token=None): - """Initialize later with Flask app instance.""" + """Initialize the FastAPI app.""" if query_client is not None: self.query_client = query_client if save_token is not None: self.save_token = save_token - self.generate_token = self.create_bearer_token_generator(app.config) + self.generate_token = create_bearer_token_generator(app.config) + + metadata_class = AuthorizationServerMetadata metadata_file = app.config.get('OAUTH2_METADATA_FILE') if metadata_file: - with open(metadata_file) as f: - metadata = self.metadata_class(json.loads(f)) + with open(metadata_file) as metadata_file_content: + metadata = metadata_class(json.loads(metadata_file_content)) metadata.validate() self.metadata = metadata self.config.setdefault('error_uris', app.config.get('OAUTH2_ERROR_URIS')) - if app.config.get('OAUTH2_JWT_ENABLED'): - deprecate('Define "get_jwt_config" in OpenID Connect grants', '1.0') - self.init_jwt_config(app.config) - - def init_jwt_config(self, config): - """Initialize JWT related configuration.""" - jwt_iss = config.get('OAUTH2_JWT_ISS') - if not jwt_iss: - raise RuntimeError('Missing "OAUTH2_JWT_ISS" configuration.') - - jwt_key_path = config.get('OAUTH2_JWT_KEY_PATH') - if jwt_key_path: - with open(jwt_key_path, 'r') as f: - if jwt_key_path.endswith('.json'): - jwt_key = json.loads(f) - else: - jwt_key = to_unicode(f.read()) - else: - jwt_key = config.get('OAUTH2_JWT_KEY') - - if not jwt_key: - raise RuntimeError('Missing "OAUTH2_JWT_KEY" configuration.') - - jwt_alg = config.get('OAUTH2_JWT_ALG') - if not jwt_alg: - raise RuntimeError('Missing "OAUTH2_JWT_ALG" configuration.') - - jwt_exp = config.get('OAUTH2_JWT_EXP', 3600) - self.config.setdefault('jwt_iss', jwt_iss) - self.config.setdefault('jwt_key', jwt_key) - self.config.setdefault('jwt_alg', jwt_alg) - self.config.setdefault('jwt_exp', jwt_exp) def get_error_uris(self, request): error_uris = self.config.get('error_uris') if error_uris: return dict(error_uris) + return None def create_oauth2_request(self, request): return OAuth2Request(request.method, str(request.url), request.body, request.headers) @@ -115,60 +55,9 @@ def create_json_request(self, request): def handle_response(self, status, body, headers): return JSONResponse(content=body, status_code=status, headers=dict(headers)) - def create_token_expires_in_generator(self, config): - """Create a generator function for generating ``expires_in`` value. - Developers can re-implement this method with a subclass if other means - required. The default expires_in value is defined by ``grant_type``, - different ``grant_type`` has different value. It can be configured - with:: - - OAUTH2_TOKEN_EXPIRES_IN = { - 'authorization_code': 864000, - 'urn:ietf:params:oauth:grant-type:jwt-bearer': 3600, - } - """ - expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN') - return create_token_expires_in_generator(expires_conf) - - def create_bearer_token_generator(self, config): - """Create a generator function for generating ``token`` value. This - method will create a Bearer Token generator with - :class:`authlib.oauth2.rfc6750.BearerToken`. By default, it will not - generate ``refresh_token``, which can be turn on by configuration - ``OAUTH2_REFRESH_TOKEN_GENERATOR=True``. - """ - conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True) - access_token_generator = create_token_generator(conf, 42) - - conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False) - refresh_token_generator = create_token_generator(conf, 48) - - expires_generator = self.create_token_expires_in_generator(config) - return BearerToken( - access_token_generator, - refresh_token_generator, - expires_generator - ) - def validate_consent_request(self, request=None, end_user=None): """Validate current HTTP request for authorization page. This page - is designed for resource owner to grant or deny the authorization:: - - @app.route('/authorize', methods=['GET']) - def authorize(): - try: - grant = server.validate_consent_request(end_user=current_user) - return render_template( - 'authorize.html', - grant=grant, - user=current_user - ) - except OAuth2Error as error: - return render_template( - 'error.html', - error=error - ) - """ + is designed for resource owner to grant or deny the authorization""" req = self.create_oauth2_request(request) req.user = end_user @@ -179,25 +68,63 @@ def authorize(): return grant -def create_token_expires_in_generator(expires_in_conf=None): +def create_bearer_token_generator(config): + """Create a generator function for generating ``token`` value. This + method will create a Bearer Token generator with + :class:`authlib.oauth2.rfc6750.BearerToken`. By default, it will not + generate ``refresh_token``, which can be turn on by configuration + ``OAUTH2_REFRESH_TOKEN_GENERATOR=True``. + """ + conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True) + access_token_generator = create_token_generator(conf, 42) + + conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False) + refresh_token_generator = create_token_generator(conf, 48) + + expires_generator = create_token_expires_in_generator(config) + + return BearerToken( + access_token_generator, + refresh_token_generator, + expires_generator + ) + + +def create_token_expires_in_generator(config): + """Create a generator function for generating ``expires_in`` value. + Developers can re-implement this method with a subclass if other means + required. The default expires_in value is defined by ``grant_type``, + different ``grant_type`` has different value. It can be configured + with:: + + OAUTH2_TOKEN_EXPIRES_IN = { + 'authorization_code': 864000 + } + """ data = {} data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) + + expires_in_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN') if expires_in_conf: data.update(expires_in_conf) - def expires_in(client, grant_type): + def expires_in(client, grant_type): # pylint: disable=W0613 return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN) return expires_in def create_token_generator(token_generator_conf, length=42): + """Create a token generator function.""" if callable(token_generator_conf): return token_generator_conf if isinstance(token_generator_conf, str): return import_string(token_generator_conf) - elif token_generator_conf is True: - def token_generator(*args, **kwargs): + + if token_generator_conf is True: + def token_generator(*args, **kwargs): # pylint: disable=W0613 return generate_token(length) return token_generator + + return None diff --git a/authlib/integrations/fastapi_oauth2/errors.py b/authlib/integrations/fastapi_oauth2/errors.py deleted file mode 100644 index bed50e55..00000000 --- a/authlib/integrations/fastapi_oauth2/errors.py +++ /dev/null @@ -1,9 +0,0 @@ -from fastapi import HTTPException - - -def raise_http_exception(status, body, headers): - raise HTTPException( - status_code=status, - detail=body, - headers=dict(headers) - ) diff --git a/authlib/integrations/fastapi_oauth2/resource_protector.py b/authlib/integrations/fastapi_oauth2/resource_protector.py index 8d43c311..5e89cea7 100644 --- a/authlib/integrations/fastapi_oauth2/resource_protector.py +++ b/authlib/integrations/fastapi_oauth2/resource_protector.py @@ -1,5 +1,8 @@ +"""Implementation of authlib.oauth2.rfc6749.ResourceProtector class for FastAPI.""" + import functools from contextlib import contextmanager +from fastapi import HTTPException from authlib.oauth2 import ( OAuth2Error, ResourceProtector as _ResourceProtector @@ -8,57 +11,15 @@ MissingAuthorizationError, HttpRequest, ) -from .errors import raise_http_exception class ResourceProtector(_ResourceProtector): - """A protecting method for resource servers. Creating a ``require_oauth`` - decorator easily with ResourceProtector:: - - from authlib.integrations.flask_oauth2 import ResourceProtector - - require_oauth = ResourceProtector() - - # add bearer token validator - from authlib.oauth2.rfc6750 import BearerTokenValidator - from project.models import Token - - class MyBearerTokenValidator(BearerTokenValidator): - def authenticate_token(self, token_string): - return Token.query.filter_by(access_token=token_string).first() - - def request_invalid(self, request): - return False - - def token_revoked(self, token): - return False - - require_oauth.register_token_validator(MyBearerTokenValidator()) - - # protect resource with require_oauth - - @app.route('/user') - @require_oauth('profile') - def user_profile(): - user = User.query.get(current_token.user_id) - return jsonify(user.to_dict()) - - """ - def raise_error_response(self, error): - """Raise HTTPException for OAuth2Error. Developers can re-implement - this method to customize the error response. - - :param error: OAuth2Error - :raise: HTTPException - """ - status = error.status_code - body = dict(error.get_body()) - headers = error.get_headers() - raise_http_exception(status, body, headers) + """ResourceProtector class.""" def acquire_token(self, request=None, scope=None, operator='AND'): """A method to acquire current valid token with the given scope. + :param request: request object :param scope: string or list of scope values :param operator: value of "AND" or "OR" :return: token object @@ -77,31 +38,36 @@ def acquire_token(self, request=None, scope=None, operator='AND'): @contextmanager def acquire(self, request=None, scope=None, operator='AND'): """The with statement of ``require_oauth``. Instead of using a - decorator, you can use a with statement instead:: - - @app.route('/api/user') - def user_api(): - with require_oauth.acquire('profile') as token: - user = User.query.get(token.user_id) - return jsonify(user.to_dict()) - """ + decorator, you can use a with statement instead.""" try: yield self.acquire_token(request, scope, operator) except OAuth2Error as error: - self.raise_error_response(error) + raise_error_response(error) def __call__(self, scope=None, operator='AND', optional=False): - def wrapper(f): - @functools.wraps(f) + def wrapper(func): + @functools.wraps(func) def decorated(*args, **kwargs): try: self.acquire_token(scope, operator) except MissingAuthorizationError as error: if optional: - return f(*args, **kwargs) - self.raise_error_response(error) + return func(*args, **kwargs) + raise_error_response(error) except OAuth2Error as error: - self.raise_error_response(error) - return f(*args, **kwargs) + raise_error_response(error) + return func(*args, **kwargs) return decorated return wrapper + + +def raise_error_response(error): + """Raise the FastAPI HTTPException method.""" + status = error.status_code + body = dict(error.get_body()) + headers = error.get_headers() + raise HTTPException( + status_code=status, + detail=body, + headers=dict(headers) + ) From 3e30635cde4637b9107c2d25d407e42561018f39 Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Fri, 16 Oct 2020 19:04:09 -0300 Subject: [PATCH 03/16] Implemented the authorization code grant pytests --- .gitignore | 1 + tests/fastapi/__init__.py | 0 tests/fastapi/test_oauth2/__init__.py | 0 tests/fastapi/test_oauth2/database.py | 15 ++ tests/fastapi/test_oauth2/models.py | 103 +++++++ tests/fastapi/test_oauth2/oauth2_server.py | 120 +++++++++ .../test_authorization_code_grant.py | 254 ++++++++++++++++++ tox.ini | 7 + 8 files changed, 500 insertions(+) create mode 100644 tests/fastapi/__init__.py create mode 100644 tests/fastapi/test_oauth2/__init__.py create mode 100644 tests/fastapi/test_oauth2/database.py create mode 100644 tests/fastapi/test_oauth2/models.py create mode 100644 tests/fastapi/test_oauth2/oauth2_server.py create mode 100644 tests/fastapi/test_oauth2/test_authorization_code_grant.py diff --git a/.gitignore b/.gitignore index b0bcd0b1..467ba799 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.pyo *.egg-info *.swp +*.db __pycache__ build develop-eggs diff --git a/tests/fastapi/__init__.py b/tests/fastapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fastapi/test_oauth2/__init__.py b/tests/fastapi/test_oauth2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fastapi/test_oauth2/database.py b/tests/fastapi/test_oauth2/database.py new file mode 100644 index 00000000..8cff1ce2 --- /dev/null +++ b/tests/fastapi/test_oauth2/database.py @@ -0,0 +1,15 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +SQLALCHEMY_DATABASE_URL = 'sqlite:///fastapi_auth2_sql.db' + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False} +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + +db = SessionLocal() diff --git a/tests/fastapi/test_oauth2/models.py b/tests/fastapi/test_oauth2/models.py new file mode 100644 index 00000000..0fdedd54 --- /dev/null +++ b/tests/fastapi/test_oauth2/models.py @@ -0,0 +1,103 @@ +import time +from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy.orm import relationship +from authlib.integrations.sqla_oauth2 import ( + OAuth2ClientMixin, + OAuth2TokenMixin, + OAuth2AuthorizationCodeMixin, +) +from authlib.oidc.core import UserInfo +from .database import Base, db + + +class User(Base): + __tablename__ = 'user' + + id = Column(Integer, primary_key=True) + username = Column(String(40), unique=True, nullable=False) + + def get_user_id(self): + return self.id + + def check_password(self, password): + return password != 'wrong' + + def generate_user_info(self, scopes): + profile = {'sub': str(self.id), 'name': self.username} + return UserInfo(profile) + + +class Client(Base, OAuth2ClientMixin): + __tablename__ = 'oauth2_client' + + id = Column(Integer, primary_key=True) + user_id = Column( + Integer, ForeignKey('user.id', ondelete='CASCADE') + ) + user = relationship('User') + + +class AuthorizationCode(Base, OAuth2AuthorizationCodeMixin): + __tablename__ = 'oauth2_code' + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, nullable=False) + + @property + def user(self): + return db.query(User).filter( + User.id == self.user_id).first() + + +class Token(Base, OAuth2TokenMixin): + __tablename__ = 'oauth2_token' + + id = Column(Integer, primary_key=True) + user_id = Column( + Integer, ForeignKey('user.id', ondelete='CASCADE') + ) + user = relationship('User') + + def is_refresh_token_expired(self): + expired_at = self.issued_at + self.expires_in * 2 + return expired_at < time.time() + + +class CodeGrantMixin(object): + def query_authorization_code(self, code, client): + item = db.query(AuthorizationCode).filter( + AuthorizationCode.code == code, + Client.client_id == client.client_id).first() + if item and not item.is_expired(): + return item + + def delete_authorization_code(self, authorization_code): + db.delete(authorization_code) + db.commit() + + def authenticate_user(self, authorization_code): + return db.query(User).filter( + User.id == authorization_code.user_id).first() + + +def save_authorization_code(code, request): + client = request.client + auth_code = AuthorizationCode( + code=code, + client_id=client.client_id, + redirect_uri=request.redirect_uri, + scope=request.scope, + nonce=request.data.get('nonce'), + user_id=request.user.id, + code_challenge=request.data.get('code_challenge'), + code_challenge_method=request.data.get('code_challenge_method'), + ) + db.add(auth_code) + db.commit() + return auth_code + + +def exists_nonce(nonce, req): + exists = db.query(AuthorizationCode).filter( + Client.client_id == req.client_id, AuthorizationCode.nonce == nonce).first() + return bool(exists) diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py new file mode 100644 index 00000000..a6915302 --- /dev/null +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -0,0 +1,120 @@ +import os +import base64 +import unittest +from fastapi import FastAPI, Request, Form +from fastapi.testclient import TestClient +from authlib.common.security import generate_token +from authlib.common.encoding import to_bytes, to_unicode +from authlib.common.urls import url_encode +from authlib.integrations.sqla_oauth2 import ( + create_query_client_func, + create_save_token_func, +) +from authlib.integrations.fastapi_oauth2 import AuthorizationServer +from authlib.oauth2 import OAuth2Error +from .models import User, Client, Token +from .database import Base, engine, db + +os.environ['AUTHLIB_INSECURE_TRANSPORT'] = 'true' + + +def token_generator(client, grant_type, user=None, scope=None): + token = '{}-{}'.format(client.client_id[0], grant_type) + if user: + token = '{}.{}'.format(token, user.get_user_id()) + return '{}.{}'.format(token, generate_token(32)) + + +def create_authorization_server(app): + query_client = create_query_client_func(db, Client) + save_token = create_save_token_func(db, Token) + + server = AuthorizationServer() + server.init_app(app, query_client, save_token) + + @app.get('/oauth/authorize') + def authorize(request: Request): + user_id = request.query_params.get('user_id') + request.body = {} + if user_id: + end_user = db.query(User).filter(User.id == int(user_id)).first() + else: + end_user = None + try: + grant = server.validate_consent_request(request=request, end_user=end_user) + return grant.prompt or 'ok' + except OAuth2Error as error: + return url_encode(error.get_body()) + + @app.post('/oauth/authorize') + def authorize(request: Request, user_id: str = Form('')): + request.body = {} + if user_id: + grant_user = db.query(User).filter(User.id == int(user_id)).first() + else: + grant_user = None + return server.create_authorization_response(request=request, grant_user=grant_user) + + @app.post('/oauth/token') + def issue_token( + request: Request, + grant_type: str = Form(...), + scope: str = Form(None), + code: str = Form(None), + refresh_token: str = Form(None), + code_verifier: str = Form(None), + client_id: str = Form(None), + client_secret: str = Form(None)): + request.body = { + 'grant_type': grant_type, + 'scope': scope, + } + if grant_type == 'authorization_code': + request.body['code'] = code + elif grant_type == 'refresh_token': + request.body['refresh_token'] = refresh_token + + if code_verifier: + request.body['code_verifier'] = code_verifier + + if client_id: + request.body['client_id'] = client_id + + if client_secret: + request.body['client_secret'] = client_secret + + return server.create_token_response(request=request) + + return server + + +def create_fastapi_app(): + app = FastAPI() + app.debug = True + app.testing = True + app.secret_key = 'testing' + app.test_client = TestClient(app) + app.config = { + 'OAUTH2_ERROR_URIS': [ + ('invalid_client', 'https://a.b/e#invalid_client') + ] + } + return app + + +class TestCase(unittest.TestCase): + def setUp(self): + app = create_fastapi_app() + + Base.metadata.create_all(bind=engine) + + self.app = app + self.client = app.test_client + + def tearDown(self): + Base.metadata.drop_all(bind=engine) + + def create_basic_header(self, username, password): + text = '{}:{}'.format(username, password) + auth = to_unicode(base64.b64encode(to_bytes(text))) + return {'Authorization': 'Basic ' + auth} diff --git a/tests/fastapi/test_oauth2/test_authorization_code_grant.py b/tests/fastapi/test_oauth2/test_authorization_code_grant.py new file mode 100644 index 00000000..367691c5 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_authorization_code_grant.py @@ -0,0 +1,254 @@ +import json +from authlib.common.urls import urlparse, url_decode +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from .database import db +from .models import User, Client, AuthorizationCode +from .models import CodeGrantMixin, save_authorization_code +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class AuthorizationCodeTest(TestCase): + + def register_grant(self, server): + server.register_grant(AuthorizationCodeGrant) + + def prepare_data( + self, is_confidential=True, + response_type='code', grant_type='authorization_code', + token_endpoint_auth_method='client_secret_basic'): + server = create_authorization_server(self.app) + self.register_grant(server) + self.server = server + + user = User(username='foo') + db.add(user) + db.commit() + + if is_confidential: + client_secret = 'code-secret' + else: + client_secret = '' + client = Client( + user_id=user.id, + client_id='code-client', + client_secret=client_secret, + ) + client.set_client_metadata({ + 'redirect_uris': ['https://a.b'], + 'scope': 'profile address', + 'token_endpoint_auth_method': token_endpoint_auth_method, + 'response_types': [response_type], + 'grant_types': grant_type.splitlines(), + }) + self.authorize_url = ( + '/oauth/authorize?response_type=code' + '&client_id=code-client' + ) + db.add(client) + db.commit() + + def test_get_authorize(self): + self.prepare_data() + rv = self.client.get(self.authorize_url) + self.assertEqual(rv.json(), 'ok') + + def test_invalid_client_id(self): + self.prepare_data() + url = '/oauth/authorize?response_type=code' + rv = self.client.get(url) + self.assertIn('invalid_client', rv.json()) + + url = '/oauth/authorize?response_type=code&client_id=invalid' + rv = self.client.get(url) + self.assertIn('invalid_client', rv.json()) + + def test_invalid_authorize(self): + self.prepare_data() + rv = self.client.post(self.authorize_url) + self.assertIn('error=access_denied', rv.headers['location']) + + self.server.metadata = {'scopes_supported': ['profile']} + rv = self.client.post(self.authorize_url + '&scope=invalid&state=foo') + self.assertIn('error=invalid_scope', rv.headers['location']) + self.assertIn('state=foo', rv.headers['location']) + + def test_unauthorized_client(self): + self.prepare_data(True, 'token') + rv = self.client.get(self.authorize_url) + self.assertIn('unauthorized_client', rv.json()) + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': 'invalid', + 'client_id': 'invalid-id', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header('code-client', 'invalid-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp['error_uri'], 'https://a.b/e#invalid_client') + + def test_invalid_code(self): + self.prepare_data() + + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + code = AuthorizationCode( + code='no-user', + client_id='code-client', + user_id=0 + ) + db.add(code) + db.commit() + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': 'no-user', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_invalid_redirect_uri(self): + self.prepare_data() + uri = self.authorize_url + '&redirect_uri=https%3A%2F%2Fa.c' + rv = self.client.post(uri, data={'user_id': '1'}) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + uri = self.authorize_url + '&redirect_uri=https%3A%2F%2Fa.b' + rv = self.client.post(uri, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_invalid_grant_type(self): + self.prepare_data( + False, token_endpoint_auth_method='none', + grant_type='invalid' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'client_id': 'code-client', + 'code': 'a', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_authorize_token_no_refresh_token(self): + self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) + self.prepare_data(False, token_endpoint_auth_method='none') + + rv = self.client.post(self.authorize_url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertNotIn('refresh_token', resp) + + def test_authorize_token_has_refresh_token(self): + # generate refresh token + self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) + self.prepare_data(grant_type='authorization_code\nrefresh_token') + url = self.authorize_url + '&state=bar' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('refresh_token', resp) + + def test_client_secret_post(self): + self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) + self.prepare_data( + grant_type='authorization_code\nrefresh_token', + token_endpoint_auth_method='client_secret_post', + ) + url = self.authorize_url + '&state=bar' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'client_id': 'code-client', + 'client_secret': 'code-secret', + 'code': code, + }) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('refresh_token', resp) + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + self.prepare_data(False, token_endpoint_auth_method='none') + + rv = self.client.post(self.authorize_url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('c-authorization_code.1.', resp['access_token']) diff --git a/tox.ini b/tox.ini index a8c5a354..19963ba7 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,7 @@ envlist = py{27,36,37,38} {py36,py37,py38} + {py27,py36,py37,py38}-fastapi {py27,py36,py37,py38}-flask {py36,py37,py38}-django coverage @@ -10,6 +11,11 @@ envlist = deps = -rrequirements-test.txt py27: unittest2 + fastapi: FastAPI + fastapi: sqlalchemy + fastapi: mangum + fastapi: werkzeug + fastapi: python-multipart flask: Flask flask: Flask-SQLAlchemy py3: httpx==0.14.3 @@ -25,6 +31,7 @@ setenv = RCFILE=setup.cfg py27: RCFILE=.py27conf py3: TESTPATH=tests/py3 + fastapi: TESTPATH=tests/fastapi flask: TESTPATH=tests/flask django: TESTPATH=tests/django commands = From cbef5fbef86099d8b7f12653b0902e8f3779714f Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Mon, 19 Oct 2020 18:53:51 -0300 Subject: [PATCH 04/16] Implemented the client credentials grant and registration pytests for FastAPI integration --- tests/fastapi/test_oauth2/oauth2_server.py | 6 +- .../test_authorization_code_grant.py | 1 - .../test_client_credentials_grant.py | 95 +++++++ .../test_client_registration_endpoint.py | 194 ++++++++++++++ .../test_oauth2/test_code_challenge.py | 225 ++++++++++++++++ .../test_oauth2/test_device_code_grant.py | 244 ++++++++++++++++++ 6 files changed, 763 insertions(+), 2 deletions(-) create mode 100644 tests/fastapi/test_oauth2/test_client_credentials_grant.py create mode 100644 tests/fastapi/test_oauth2/test_client_registration_endpoint.py create mode 100644 tests/fastapi/test_oauth2/test_code_challenge.py create mode 100644 tests/fastapi/test_oauth2/test_device_code_grant.py diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py index a6915302..2017b569 100644 --- a/tests/fastapi/test_oauth2/oauth2_server.py +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -64,7 +64,8 @@ def issue_token( refresh_token: str = Form(None), code_verifier: str = Form(None), client_id: str = Form(None), - client_secret: str = Form(None)): + client_secret: str = Form(None), + device_code: str = Form(None)): request.body = { 'grant_type': grant_type, 'scope': scope, @@ -83,6 +84,9 @@ def issue_token( if client_secret: request.body['client_secret'] = client_secret + if device_code: + request.body['device_code'] = device_code + return server.create_token_response(request=request) return server diff --git a/tests/fastapi/test_oauth2/test_authorization_code_grant.py b/tests/fastapi/test_oauth2/test_authorization_code_grant.py index 367691c5..edb0be14 100644 --- a/tests/fastapi/test_oauth2/test_authorization_code_grant.py +++ b/tests/fastapi/test_oauth2/test_authorization_code_grant.py @@ -1,4 +1,3 @@ -import json from authlib.common.urls import urlparse, url_decode from authlib.oauth2.rfc6749.grants import ( AuthorizationCodeGrant as _AuthorizationCodeGrant, diff --git a/tests/fastapi/test_oauth2/test_client_credentials_grant.py b/tests/fastapi/test_oauth2/test_client_credentials_grant.py new file mode 100644 index 00000000..77d56551 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_client_credentials_grant.py @@ -0,0 +1,95 @@ +from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant +from .database import db +from .models import User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class ClientCredentialsTest(TestCase): + def prepare_data(self, grant_type='client_credentials'): + server = create_authorization_server(self.app) + server.register_grant(ClientCredentialsGrant) + self.server = server + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='credential-client', + client_secret='credential-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + 'grant_types': [grant_type] + }) + db.add(client) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'credential-client', 'invalid-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_grant_type(self): + self.prepare_data(grant_type='invalid') + headers = self.create_basic_header( + 'credential-client', 'credential-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_invalid_scope(self): + self.prepare_data() + self.server.metadata = {'scopes_supported': ['profile']} + headers = self.create_basic_header( + 'credential-client', 'credential-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'scope': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_scope') + + def test_authorize_token(self): + self.prepare_data() + headers = self.create_basic_header( + 'credential-client', 'credential-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + + self.prepare_data() + headers = self.create_basic_header( + 'credential-client', 'credential-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('c-client_credentials.', resp['access_token']) diff --git a/tests/fastapi/test_oauth2/test_client_registration_endpoint.py b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py new file mode 100644 index 00000000..99bf9add --- /dev/null +++ b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py @@ -0,0 +1,194 @@ +from pydantic import BaseModel +from fastapi import Request +from authlib.jose import jwt +from authlib.oauth2.rfc7591 import ClientRegistrationEndpoint as _ClientRegistrationEndpoint +from tests.util import read_file_path +from .database import db +from .models import User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class ClientRegistrationEndpoint(_ClientRegistrationEndpoint): + software_statement_alg_values_supported = ['RS256'] + + def authenticate_token(self, request): + auth_header = request.headers.get('Authorization') + if auth_header: + request.user_id = 1 + return auth_header + + def resolve_public_key(self, request): + return read_file_path('rsa_public.pem') + + def save_client(self, client_info, client_metadata, request): + client = Client( + user_id=request.user_id, + **client_info + ) + client.set_client_metadata(client_metadata) + db.add(client) + db.commit() + return client + + +class ClientRegistrationTest(TestCase): + def prepare_data(self, endpoint_cls=None, metadata=None): + app = self.app + server = create_authorization_server(app) + if metadata: + server.metadata = metadata + + if endpoint_cls is None: + endpoint_cls = ClientRegistrationEndpoint + server.register_endpoint(endpoint_cls) + + class Item(BaseModel): + client_name: str = None + client_uri: str = None + redirect_uri: str = None + scope: str = None + software_statement: str = None + token_endpoint_auth_method: str = None + grant_types: list = None + response_types: list = None + + @app.post('/create_client') + def create_client(request: Request, item: Item = None): + request.body = {} + if item: + request.body = { + 'client_name': item.client_name, + 'client_uri': item.client_uri, + 'redirect_uri': item.redirect_uri, + 'scope': item.scope, + 'software_statement': item.software_statement, + 'token_endpoint_auth_method': item.token_endpoint_auth_method, + 'grant_types': item.grant_types, + 'response_types': item.response_types, + } + return server.create_endpoint_response('client_registration', request=request) + + user = User(username='foo') + db.add(user) + db.commit() + + def test_access_denied(self): + self.prepare_data() + rv = self.client.post('/create_client') + resp = rv.json() + self.assertEqual(resp['error'], 'access_denied') + + def test_invalid_request(self): + self.prepare_data() + headers = {'Authorization': 'bearer abc'} + rv = self.client.post('/create_client', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_create_client(self): + self.prepare_data() + headers = {'Authorization': 'bearer abc'} + body = { + 'client_name': 'Authlib' + } + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn('client_id', resp) + self.assertEqual(resp['client_name'], 'Authlib') + + def test_software_statement(self): + payload = {'software_id': 'uuid-123', 'client_name': 'Authlib'} + s = jwt.encode({'alg': 'RS256'}, payload, read_file_path('rsa_private.pem')) + body = { + 'software_statement': s.decode('utf-8'), + } + + self.prepare_data() + headers = {'Authorization': 'bearer abc'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn('client_id', resp) + self.assertEqual(resp['client_name'], 'Authlib') + + def test_no_public_key(self): + + class ClientRegistrationEndpoint2(ClientRegistrationEndpoint): + def resolve_public_key(self, request): + return None + + payload = {'software_id': 'uuid-123', 'client_name': 'Authlib'} + s = jwt.encode({'alg': 'RS256'}, payload, read_file_path('rsa_private.pem')) + body = { + 'software_statement': s.decode('utf-8'), + } + + self.prepare_data(ClientRegistrationEndpoint2) + headers = {'Authorization': 'bearer abc'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn(resp['error'], 'unapproved_software_statement') + + def test_scopes_supported(self): + metadata = {'scopes_supported': ['profile', 'email']} + self.prepare_data(metadata=metadata) + + headers = {'Authorization': 'bearer abc'} + body = {'scope': 'profile email', 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn('client_id', resp) + self.assertEqual(resp['client_name'], 'Authlib') + + body = {'scope': 'profile email address', 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn(resp['error'], 'invalid_client_metadata') + + def test_response_types_supported(self): + metadata = {'response_types_supported': ['code']} + self.prepare_data(metadata=metadata) + + headers = {'Authorization': 'bearer abc'} + body = {'response_types': ['code'], 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn('client_id', resp) + self.assertEqual(resp['client_name'], 'Authlib') + + body = {'response_types': ['code', 'token'], 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn(resp['error'], 'invalid_client_metadata') + + def test_grant_types_supported(self): + metadata = {'grant_types_supported': ['authorization_code', 'password']} + self.prepare_data(metadata=metadata) + + headers = {'Authorization': 'bearer abc'} + body = {'grant_types': ['password'], 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn('client_id', resp) + self.assertEqual(resp['client_name'], 'Authlib') + + body = {'grant_types': ['client_credentials'], 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn(resp['error'], 'invalid_client_metadata') + + def test_token_endpoint_auth_methods_supported(self): + metadata = {'token_endpoint_auth_methods_supported': ['client_secret_basic']} + self.prepare_data(metadata=metadata) + + headers = {'Authorization': 'bearer abc'} + body = {'token_endpoint_auth_method': 'client_secret_basic', 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn('client_id', resp) + self.assertEqual(resp['client_name'], 'Authlib') + + body = {'token_endpoint_auth_method': 'none', 'client_name': 'Authlib'} + rv = self.client.post('/create_client', json=body, headers=headers) + resp = rv.json() + self.assertIn(resp['error'], 'invalid_client_metadata') diff --git a/tests/fastapi/test_oauth2/test_code_challenge.py b/tests/fastapi/test_oauth2/test_code_challenge.py new file mode 100644 index 00000000..90df8186 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_code_challenge.py @@ -0,0 +1,225 @@ +from authlib.common.security import generate_token +from authlib.common.urls import urlparse, url_decode +from authlib.oauth2.rfc6749 import grants +from authlib.oauth2.rfc7636 import ( + CodeChallenge as _CodeChallenge, + create_s256_code_challenge, +) +from .database import db +from .models import User, Client +from .models import CodeGrantMixin, save_authorization_code +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class AuthorizationCodeGrant(CodeGrantMixin, grants.AuthorizationCodeGrant): + TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class CodeChallenge(_CodeChallenge): + SUPPORTED_CODE_CHALLENGE_METHOD = ['plain', 'S256', 'S128'] + + +class CodeChallengeTest(TestCase): + def prepare_data(self, token_endpoint_auth_method='none'): + server = create_authorization_server(self.app) + server.register_grant( + AuthorizationCodeGrant, + [CodeChallenge(required=True)] + ) + + user = User(username='foo') + db.add(user) + db.commit() + + client_secret = '' + if token_endpoint_auth_method != 'none': + client_secret = 'code-secret' + + client = Client( + user_id=user.id, + client_id='code-client', + client_secret=client_secret, + ) + client.set_client_metadata({ + 'redirect_uris': ['https://a.b'], + 'scope': 'profile address', + 'token_endpoint_auth_method': token_endpoint_auth_method, + 'response_types': ['code'], + 'grant_types': ['authorization_code'], + }) + self.authorize_url = ( + '/oauth/authorize?response_type=code' + '&client_id=code-client' + ) + db.add(client) + db.commit() + + def test_missing_code_challenge(self): + self.prepare_data() + rv = self.client.get(self.authorize_url + '&code_challenge_method=plain') + self.assertIn('Missing', rv.json()) + + def test_has_code_challenge(self): + self.prepare_data() + rv = self.client.get(self.authorize_url + '&code_challenge=abc') + self.assertEqual(rv.json(), 'ok') + + def test_invalid_code_challenge_method(self): + self.prepare_data() + suffix = '&code_challenge=abc&code_challenge_method=invalid' + rv = self.client.get(self.authorize_url + suffix) + self.assertIn('Unsupported', rv.json()) + + def test_supported_code_challenge_method(self): + self.prepare_data() + suffix = '&code_challenge=abc&code_challenge_method=plain' + rv = self.client.get(self.authorize_url + suffix) + self.assertEqual(rv.json(), 'ok') + + def test_trusted_client_without_code_challenge(self): + self.prepare_data('client_secret_basic') + rv = self.client.get(self.authorize_url) + self.assertEqual(rv.json(), 'ok') + + rv = self.client.post(self.authorize_url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_missing_code_verifier(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('Missing', resp['error_description']) + + def test_trusted_client_missing_code_verifier(self): + self.prepare_data('client_secret_basic') + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('Missing', resp['error_description']) + + def test_plain_code_challenge_invalid(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': 'bar', + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('Invalid', resp['error_description']) + + def test_plain_code_challenge_failed(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': generate_token(48), + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('failed', resp['error_description']) + + def test_plain_code_challenge_success(self): + self.prepare_data() + code_verifier = generate_token(48) + url = self.authorize_url + '&code_challenge=' + code_verifier + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': code_verifier, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_s256_code_challenge_success(self): + self.prepare_data() + code_verifier = generate_token(48) + code_challenge = create_s256_code_challenge(code_verifier) + url = self.authorize_url + '&code_challenge=' + code_challenge + url += '&code_challenge_method=S256' + + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': code_verifier, + 'client_id': 'code-client', + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_not_implemented_code_challenge_method(self): + self.prepare_data() + url = self.authorize_url + '&code_challenge=foo' + url += '&code_challenge_method=S128' + + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + code = params['code'] + self.assertRaises( + RuntimeError, self.client.post, '/oauth/token', + data={ + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': generate_token(48), + 'client_id': 'code-client', + } + ) diff --git a/tests/fastapi/test_oauth2/test_device_code_grant.py b/tests/fastapi/test_oauth2/test_device_code_grant.py new file mode 100644 index 00000000..bc76a7ff --- /dev/null +++ b/tests/fastapi/test_oauth2/test_device_code_grant.py @@ -0,0 +1,244 @@ +import time +from fastapi import Request, Form +from authlib.oauth2.rfc8628 import ( + DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint, + DeviceCodeGrant as _DeviceCodeGrant, + DeviceCredentialDict, +) +from .database import db +from .models import User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +device_credentials = { + 'valid-device': { + 'client_id': 'client', + 'expires_in': 1800, + 'user_code': 'code', + }, + 'expired-token': { + 'client_id': 'client', + 'expires_in': -100, + 'user_code': 'none', + }, + 'invalid-client': { + 'client_id': 'invalid', + 'expires_in': 1800, + 'user_code': 'none', + }, + 'denied-code': { + 'client_id': 'client', + 'expires_in': 1800, + 'user_code': 'denied', + }, + 'grant-code': { + 'client_id': 'client', + 'expires_in': 1800, + 'user_code': 'code', + }, + 'pending-code': { + 'client_id': 'client', + 'expires_in': 1800, + 'user_code': 'none', + } +} + +class DeviceCodeGrant(_DeviceCodeGrant): + def query_device_credential(self, device_code): + data = device_credentials.get(device_code) + if not data: + return None + + now = int(time.time()) + data['expires_at'] = now + data['expires_in'] + data['device_code'] = device_code + data['scope'] = 'profile' + data['interval'] = 5 + data['verification_uri'] = 'https://example.com/activate' + return DeviceCredentialDict(data) + + def query_user_grant(self, user_code): + if user_code == 'code': + return db.query(User).filter(User.id == 1).first(), True + if user_code == 'denied': + return db.query(User).filter(User.id == 1).first(), False + return None + + def should_slow_down(self, credential, now): + return False + + +class DeviceCodeGrantTest(TestCase): + def create_server(self): + server = create_authorization_server(self.app) + server.register_grant(DeviceCodeGrant) + self.server = server + return server + + def prepare_data(self, grant_type=DeviceCodeGrant.GRANT_TYPE): + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='client', + client_secret='secret', + ) + client.set_client_metadata({ + 'redirect_uris': ['http://localhost/authorized'], + 'scope': 'profile', + 'grant_types': [grant_type], + }) + db.add(client) + db.commit() + + def test_invalid_request(self): + self.create_server() + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'valid-device', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'missing', + 'client_id': 'client', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_unauthorized_client(self): + self.create_server() + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'valid-device', + 'client_id': 'invalid', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + self.prepare_data(grant_type='password') + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'valid-device', + 'client_id': 'client', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_invalid_client(self): + self.create_server() + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'invalid-client', + 'client_id': 'invalid', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_expired_token(self): + self.create_server() + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'expired-token', + 'client_id': 'client', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'expired_token') + + def test_denied_by_user(self): + self.create_server() + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'denied-code', + 'client_id': 'client', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'access_denied') + + def test_authorization_pending(self): + self.create_server() + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'pending-code', + 'client_id': 'client', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'authorization_pending') + + def test_get_access_token(self): + self.create_server() + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': DeviceCodeGrant.GRANT_TYPE, + 'device_code': 'grant-code', + 'client_id': 'client', + }) + resp = rv.json() + self.assertIn('access_token', resp) + + +class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): + def get_verification_uri(self): + return 'https://example.com/activate' + + def save_device_credential(self, client_id, scope, data): + pass + + +class DeviceAuthorizationEndpointTest(TestCase): + def create_server(self): + server = create_authorization_server(self.app) + server.register_endpoint(DeviceAuthorizationEndpoint) + self.server = server + + @self.app.post('/device_authorize') + def device_authorize(request: Request, + scope: str = Form(None), + client_id: str = Form(None)): + request.body = { + 'scope': scope, + 'client_id': client_id, + } + name = DeviceAuthorizationEndpoint.ENDPOINT_NAME + return server.create_endpoint_response(name, request=request) + + return server + + def test_missing_client_id(self): + self.create_server() + rv = self.client.post('/device_authorize', data={ + 'scope': 'profile' + }) + self.assertEqual(rv.status_code, 400) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_create_authorization_response(self): + self.create_server() + rv = self.client.post('/device_authorize', data={ + 'client_id': 'client', + }) + self.assertEqual(rv.status_code, 200) + resp = rv.json() + self.assertIn('device_code', resp) + self.assertIn('user_code', resp) + self.assertEqual(resp['verification_uri'], 'https://example.com/activate') + self.assertEqual( + resp['verification_uri_complete'], + 'https://example.com/activate?user_code=' + resp['user_code'] + ) From 30dad2f52c4d396f12e7216a706ccd20f314330a Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Tue, 20 Oct 2020 16:11:18 -0300 Subject: [PATCH 05/16] Implemented the implicit grant pytest for FastAPI integration --- .../test_oauth2/test_implicit_grant.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 tests/fastapi/test_oauth2/test_implicit_grant.py diff --git a/tests/fastapi/test_oauth2/test_implicit_grant.py b/tests/fastapi/test_oauth2/test_implicit_grant.py new file mode 100644 index 00000000..a08b0a7f --- /dev/null +++ b/tests/fastapi/test_oauth2/test_implicit_grant.py @@ -0,0 +1,83 @@ +from authlib.oauth2.rfc6749.grants import ImplicitGrant +from .database import db +from .models import User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class ImplicitTest(TestCase): + def prepare_data(self, is_confidential=False, response_type='token'): + server = create_authorization_server(self.app) + server.register_grant(ImplicitGrant) + self.server = server + + user = User(username='foo') + db.add(user) + db.commit() + if is_confidential: + client_secret = 'implicit-secret' + token_endpoint_auth_method = 'client_secret_basic' + else: + client_secret = '' + token_endpoint_auth_method = 'none' + + client = Client( + user_id=user.id, + client_id='implicit-client', + client_secret=client_secret, + ) + client.set_client_metadata({ + 'redirect_uris': ['http://localhost/authorized'], + 'scope': 'profile', + 'response_types': [response_type], + 'grant_types': ['implicit'], + 'token_endpoint_auth_method': token_endpoint_auth_method, + }) + self.authorize_url = ( + '/oauth/authorize?response_type=token' + '&client_id=implicit-client' + ) + db.add(client) + db.commit() + + def test_get_authorize(self): + self.prepare_data() + rv = self.client.get(self.authorize_url) + self.assertEqual(rv.json(), 'ok') + + def test_confidential_client(self): + self.prepare_data(True) + rv = self.client.get(self.authorize_url) + self.assertIn('invalid_client', rv.json()) + + def test_unsupported_client(self): + self.prepare_data(response_type='code') + rv = self.client.get(self.authorize_url) + self.assertIn('unauthorized_client', rv.json()) + + def test_invalid_authorize(self): + self.prepare_data() + rv = self.client.post(self.authorize_url) + self.assertIn('#error=access_denied', rv.headers['location']) + + self.server.metadata = {'scopes_supported': ['profile']} + rv = self.client.post(self.authorize_url + '&scope=invalid') + self.assertIn('#error=invalid_scope', rv.headers['location']) + + def test_authorize_token(self): + self.prepare_data() + rv = self.client.post(self.authorize_url, data={'user_id': '1'}) + self.assertIn('access_token=', rv.headers['location']) + + url = self.authorize_url + '&state=bar&scope=profile' + rv = self.client.post(url, data={'user_id': '1'}) + self.assertIn('access_token=', rv.headers['location']) + self.assertIn('state=bar', rv.headers['location']) + self.assertIn('scope=profile', rv.headers['location']) + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + self.prepare_data() + rv = self.client.post(self.authorize_url, data={'user_id': '1'}) + self.assertIn('access_token=i-implicit.1.', rv.headers['location']) From 340cabd3cc8f0a774d8085af5a865c46615aa07a Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Tue, 20 Oct 2020 16:48:54 -0300 Subject: [PATCH 06/16] Implemented the introspection and jwt bearer pytests for FastAPI integration --- tests/fastapi/test_oauth2/oauth2_server.py | 26 ++- .../test_introspection_endpoint.py | 166 ++++++++++++++++++ .../test_jwt_bearer_client_auth.py | 153 ++++++++++++++++ .../test_oauth2/test_jwt_bearer_grant.py | 106 +++++++++++ 4 files changed, 444 insertions(+), 7 deletions(-) create mode 100644 tests/fastapi/test_oauth2/test_introspection_endpoint.py create mode 100644 tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py create mode 100644 tests/fastapi/test_oauth2/test_jwt_bearer_grant.py diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py index 2017b569..76095f1b 100644 --- a/tests/fastapi/test_oauth2/oauth2_server.py +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -65,27 +65,39 @@ def issue_token( code_verifier: str = Form(None), client_id: str = Form(None), client_secret: str = Form(None), - device_code: str = Form(None)): + device_code: str = Form(None), + client_assertion_type: str = Form(None), + client_assertion: str = Form(None), + assertion: str = Form(None)): request.body = { 'grant_type': grant_type, 'scope': scope, } if grant_type == 'authorization_code': - request.body['code'] = code + request.body.update({'code': code}) elif grant_type == 'refresh_token': - request.body['refresh_token'] = refresh_token + request.body.update({'refresh_token': refresh_token}) if code_verifier: - request.body['code_verifier'] = code_verifier + request.body.update({'code_verifier': code_verifier}) if client_id: - request.body['client_id'] = client_id + request.body.update({'client_id': client_id}) if client_secret: - request.body['client_secret'] = client_secret + request.body.update({'client_secret': client_secret}) if device_code: - request.body['device_code'] = device_code + request.body.update({'device_code': device_code}) + + if client_assertion_type: + request.body.update({'client_assertion_type': client_assertion_type}) + + if client_assertion: + request.body.update({'client_assertion': client_assertion}) + + if assertion: + request.body.update({'assertion': assertion}) return server.create_token_response(request=request) diff --git a/tests/fastapi/test_oauth2/test_introspection_endpoint.py b/tests/fastapi/test_oauth2/test_introspection_endpoint.py new file mode 100644 index 00000000..ee51d0e0 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_introspection_endpoint.py @@ -0,0 +1,166 @@ +from fastapi import Request, Form +from authlib.integrations.sqla_oauth2 import create_query_token_func +from authlib.oauth2.rfc7662 import IntrospectionEndpoint +from .database import db +from .models import User, Client, Token +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +query_token = create_query_token_func(db, Token) + + +class MyIntrospectionEndpoint(IntrospectionEndpoint): + def query_token(self, token, token_type_hint, client): + return query_token(token, token_type_hint, client) + + def introspect_token(self, token): + user = db.query(User).filter(User.id == int(token.user_id)).first() + return { + "active": not token.revoked, + "client_id": token.client_id, + "username": user.username, + "scope": token.scope, + "sub": user.get_user_id(), + "aud": token.client_id, + "iss": "https://server.example.com/", + "exp": token.get_expires_at(), + "iat": token.issued_at, + } + + +class IntrospectTokenTest(TestCase): + def prepare_data(self): + app = self.app + + server = create_authorization_server(app) + server.register_endpoint(MyIntrospectionEndpoint) + + @app.post('/oauth/introspect') + def introspect_token(request: Request, + token: str = Form(None), + token_type_hint: str = Form(None)): + request.body = {} + + if token: + request.body.update({'token': token}) + + if token_type_hint: + request.body.update({'token_type_hint': token_type_hint}) + + return server.create_endpoint_response('introspection', request=request) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='introspect-client', + client_secret='introspect-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://a.b/c'], + }) + db.add(client) + db.commit() + + def create_token(self): + token = Token( + user_id=1, + client_id='introspect-client', + token_type='bearer', + access_token='a1', + refresh_token='r1', + scope='profile', + expires_in=3600, + ) + db.add(token) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/introspect') + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = {'Authorization': 'invalid token_string'} + rv = self.client.post('/oauth/introspect', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'invalid-client', 'introspect-secret' + ) + rv = self.client.post('/oauth/introspect', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'introspect-client', 'invalid-secret' + ) + rv = self.client.post('/oauth/introspect', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_token(self): + self.prepare_data() + headers = self.create_basic_header( + 'introspect-client', 'introspect-secret' + ) + rv = self.client.post('/oauth/introspect', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/introspect', data={ + 'token_type_hint': 'refresh_token', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/introspect', data={ + 'token': 'a1', + 'token_type_hint': 'unsupported_token_type', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unsupported_token_type') + + rv = self.client.post('/oauth/introspect', data={ + 'token': 'invalid-token', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['active'], False) + + rv = self.client.post('/oauth/introspect', data={ + 'token': 'a1', + 'token_type_hint': 'refresh_token', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['active'], False) + + def test_introspect_token_with_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'introspect-client', 'introspect-secret' + ) + rv = self.client.post('/oauth/introspect', data={ + 'token': 'a1', + 'token_type_hint': 'access_token', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + resp = rv.json() + self.assertEqual(resp['client_id'], 'introspect-client') + + def test_introspect_token_without_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'introspect-client', 'introspect-secret' + ) + rv = self.client.post('/oauth/introspect', data={ + 'token': 'a1', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + resp = rv.json() + self.assertEqual(resp['client_id'], 'introspect-client') diff --git a/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py b/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py new file mode 100644 index 00000000..724924ff --- /dev/null +++ b/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py @@ -0,0 +1,153 @@ +from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant +from authlib.oauth2.rfc7523 import ( + JWTBearerClientAssertion, + client_secret_jwt_sign, + private_key_jwt_sign, +) +from tests.util import read_file_path +from .database import db +from .models import User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class JWTClientCredentialsGrant(ClientCredentialsGrant): + TOKEN_ENDPOINT_AUTH_METHODS = [ + JWTBearerClientAssertion.CLIENT_AUTH_METHOD, + ] + + +class JWTClientAuth(JWTBearerClientAssertion): + def validate_jti(self, claims, jti): + return True + + def resolve_client_public_key(self, client, headers): + if headers['alg'] == 'RS256': + return read_file_path('jwk_public.json') + return client.client_secret + + +class ClientCredentialsTest(TestCase): + def prepare_data(self, auth_method, validate_jti=True): + server = create_authorization_server(self.app) + server.register_grant(JWTClientCredentialsGrant) + server.register_client_auth_method( + JWTClientAuth.CLIENT_AUTH_METHOD, + JWTClientAuth('https://localhost/oauth/token', validate_jti) + ) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='credential-client', + client_secret='credential-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + 'grant_types': ['client_credentials'], + 'token_endpoint_auth_method': auth_method, + }) + db.add(client) + db.commit() + + def test_invalid_client(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_jwt(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='invalid-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_not_found_client(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='invalid-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_not_supported_auth_method(self): + self.prepare_data('invalid') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_client_secret_jwt(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + claims={'jti': 'nonce'}, + ) + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_private_key_jwt(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': private_key_jwt_sign( + private_key=read_file_path('jwk_private.json'), + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_not_validate_jti(self): + self.prepare_data(JWTBearerClientAssertion.CLIENT_AUTH_METHOD, False) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'client_credentials', + 'client_assertion_type': JWTBearerClientAssertion.CLIENT_ASSERTION_TYPE, + 'client_assertion': client_secret_jwt_sign( + client_secret='credential-secret', + client_id='credential-client', + token_endpoint='https://localhost/oauth/token', + ) + }) + resp = rv.json() + self.assertIn('access_token', resp) diff --git a/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py b/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py new file mode 100644 index 00000000..c8587829 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py @@ -0,0 +1,106 @@ +from authlib.oauth2.rfc7523 import JWTBearerGrant as _JWTBearerGrant +from .database import db +from .models import User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class JWTBearerGrant(_JWTBearerGrant): + def authenticate_user(self, client, claims): + return None + + def authenticate_client(self, claims): + iss = claims['iss'] + return db.query(Client).filter(Client.client_id == iss).first() + + def resolve_public_key(self, headers, payload): + keys = {'1': 'foo', '2': 'bar'} + return keys[headers['kid']] + + +class JWTBearerGrantTest(TestCase): + def prepare_data(self, grant_type=None): + server = create_authorization_server(self.app) + server.register_grant(JWTBearerGrant) + + user = User(username='foo') + db.add(user) + db.commit() + if grant_type is None: + grant_type = JWTBearerGrant.GRANT_TYPE + client = Client( + user_id=user.id, + client_id='jwt-client', + client_secret='jwt-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + 'grant_types': [grant_type], + }) + db.add(client) + db.commit() + + def test_missing_assertion(self): + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + self.assertIn('assertion', resp['error_description']) + + def test_invalid_assertion(self): + self.prepare_data() + assertion = JWTBearerGrant.sign( + 'foo', issuer='jwt-client', audience='https://i.b/token', + header={'alg': 'HS256', 'kid': '1'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_grant') + + def test_authorize_token(self): + self.prepare_data() + assertion = JWTBearerGrant.sign( + 'foo', issuer='jwt-client', audience='https://i.b/token', + subject='self', header={'alg': 'HS256', 'kid': '1'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_unauthorized_client(self): + self.prepare_data('password') + assertion = JWTBearerGrant.sign( + 'bar', issuer='jwt-client', audience='https://i.b/token', + subject='self', header={'alg': 'HS256', 'kid': '2'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + self.prepare_data() + assertion = JWTBearerGrant.sign( + 'foo', issuer='jwt-client', audience='https://i.b/token', + subject='self', header={'alg': 'HS256', 'kid': '1'} + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': JWTBearerGrant.GRANT_TYPE, + 'assertion': assertion + }) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('j-', resp['access_token']) From a167172e553275b61b92b3bf00bb8cd2830dc0f7 Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Tue, 20 Oct 2020 18:30:59 -0300 Subject: [PATCH 07/16] Implemented the oauth2 server pytests for FastAPI integration --- .../fastapi/test_oauth2/test_oauth2_server.py | 184 ++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 tests/fastapi/test_oauth2/test_oauth2_server.py diff --git a/tests/fastapi/test_oauth2/test_oauth2_server.py b/tests/fastapi/test_oauth2/test_oauth2_server.py new file mode 100644 index 00000000..94407188 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_oauth2_server.py @@ -0,0 +1,184 @@ +from fastapi import Request +from authlib.integrations.fastapi_oauth2 import ResourceProtector +from authlib.integrations.sqla_oauth2 import create_bearer_token_validator +from .database import db +from .models import User, Client, Token +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + +require_oauth = ResourceProtector() +BearerTokenValidator = create_bearer_token_validator(db, Token) +require_oauth.register_token_validator(BearerTokenValidator()) + + +def create_resource_server(app): + @app.get('/user') + def user_profile(request: Request): + with require_oauth.acquire(request, 'profile') as token: + user = token.user + return {'id': user.id, 'username': user.username} + + @app.get('/user/email') + def user_email(request: Request): + with require_oauth.acquire(request, 'email') as token: + user = token.user + return {'email': user.username + '@example.com'} + + @app.get('/info') + def public_info(request: Request): + with require_oauth.acquire(request) as token: + return {'status': 'ok'} + + @app.get('/operator-and') + def operator_and(request: Request): + with require_oauth.acquire(request, 'profile email', 'AND') as token: + return {'status': 'ok'} + + @app.get('/operator-or') + def operator_or(request: Request): + with require_oauth.acquire(request, 'profile email', 'OR') as token: + return {'status': 'ok'} + + def scope_operator(token_scopes, resource_scopes): + return 'profile' in token_scopes and 'email' not in token_scopes + + @app.get('/operator-func') + def operator_func(request: Request): + with require_oauth.acquire(request, operator=scope_operator) as token: + return {'status': 'ok'} + + @app.get('/acquire') + def test_acquire(request: Request): + with require_oauth.acquire(request, 'profile') as token: + user = token.user + return {'id': user.id, 'username': user.username} + + +class AuthorizationTest(TestCase): + def test_none_grant(self): + create_authorization_server(self.app) + authorize_url = ( + '/oauth/authorize?response_type=token' + '&client_id=implicit-client' + ) + rv = self.client.get(authorize_url) + self.assertIn('invalid_grant', rv.json()) + + rv = self.client.post(authorize_url, data={'user_id': '1'}) + self.assertNotEqual(rv.status_code, 200) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'code': 'x', + }) + data = rv.json() + self.assertEqual(data['error'], 'unsupported_grant_type') + + +class ResourceTest(TestCase): + def prepare_data(self): + create_resource_server(self.app) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='resource-client', + client_secret='resource-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + }) + db.add(client) + db.commit() + + def create_token(self, expires_in=3600): + token = Token( + user_id=1, + client_id='resource-client', + token_type='bearer', + access_token='a1', + scope='profile', + expires_in=expires_in, + ) + db.add(token) + db.commit() + + def create_bearer_header(self, token): + return {'Authorization': 'Bearer ' + token} + + def test_invalid_token(self): + self.prepare_data() + + rv = self.client.get('/user') + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp['detail']['error'], 'missing_authorization') + + headers = {'Authorization': 'invalid token'} + rv = self.client.get('/user', headers=headers) + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp['detail']['error'], 'unsupported_token_type') + + headers = self.create_bearer_header('invalid') + rv = self.client.get('/user', headers=headers) + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp['detail']['error'], 'invalid_token') + + def test_expired_token(self): + self.prepare_data() + self.create_token(0) + headers = self.create_bearer_header('a1') + + rv = self.client.get('/user', headers=headers) + self.assertEqual(rv.status_code, 401) + resp = rv.json() + self.assertEqual(resp['detail']['error'], 'invalid_token') + + rv = self.client.get('/acquire', headers=headers) + self.assertEqual(rv.status_code, 401) + + def test_insufficient_token(self): + self.prepare_data() + self.create_token() + headers = self.create_bearer_header('a1') + rv = self.client.get('/user/email', headers=headers) + self.assertEqual(rv.status_code, 403) + resp = rv.json() + self.assertEqual(resp['detail']['error'], 'insufficient_scope') + + def test_access_resource(self): + self.prepare_data() + self.create_token() + headers = self.create_bearer_header('a1') + + rv = self.client.get('/user', headers=headers) + resp = rv.json() + self.assertEqual(resp['username'], 'foo') + + rv = self.client.get('/acquire', headers=headers) + resp = rv.json() + self.assertEqual(resp['username'], 'foo') + + rv = self.client.get('/info', headers=headers) + resp = rv.json() + self.assertEqual(resp['status'], 'ok') + + def test_scope_operator(self): + self.prepare_data() + self.create_token() + headers = self.create_bearer_header('a1') + rv = self.client.get('/operator-and', headers=headers) + self.assertEqual(rv.status_code, 403) + resp = rv.json() + self.assertEqual(resp['detail']['error'], 'insufficient_scope') + + rv = self.client.get('/operator-or', headers=headers) + self.assertEqual(rv.status_code, 200) + + rv = self.client.get('/operator-func', headers=headers) + self.assertEqual(rv.status_code, 200) From b826d2952e72b26695c9b3bdb88443ff3c0bf146 Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Wed, 21 Oct 2020 15:34:11 -0300 Subject: [PATCH 08/16] Implemented the openid code grant pytests for FastAPI integration --- tests/fastapi/test_oauth2/oauth2_server.py | 44 ++- .../test_oauth2/test_openid_code_grant.py | 275 ++++++++++++++++++ 2 files changed, 315 insertions(+), 4 deletions(-) create mode 100644 tests/fastapi/test_oauth2/test_openid_code_grant.py diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py index 76095f1b..a4b9a5b2 100644 --- a/tests/fastapi/test_oauth2/oauth2_server.py +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -33,7 +33,7 @@ def create_authorization_server(app): server.init_app(app, query_client, save_token) @app.get('/oauth/authorize') - def authorize(request: Request): + def authorize_get(request: Request): user_id = request.query_params.get('user_id') request.body = {} if user_id: @@ -47,12 +47,44 @@ def authorize(request: Request): return url_encode(error.get_body()) @app.post('/oauth/authorize') - def authorize(request: Request, user_id: str = Form('')): - request.body = {} + def authorize_post(request: Request, + response_type: str = Form(None), + client_id: str = Form(None), + state: str = Form(None), + scope: str = Form(None), + nonce: str = Form(None), + redirect_uri: str = Form(None), + user_id: str = Form(None)): + if not user_id: + user_id = request.query_params.get('user_id') + + request.body = { + 'user_id': user_id + } + + if response_type: + request.body.update({'response_type': response_type}) + + if client_id: + request.body.update({'client_id': client_id}) + + if state: + request.body.update({'state': state}) + + if nonce: + request.body.update({'nonce': nonce}) + + if scope: + request.body.update({'scope': scope}) + + if redirect_uri: + request.body.update({'redirect_uri': redirect_uri}) + if user_id: grant_user = db.query(User).filter(User.id == int(user_id)).first() else: grant_user = None + return server.create_authorization_response(request=request, grant_user=grant_user) @app.post('/oauth/token') @@ -68,7 +100,8 @@ def issue_token( device_code: str = Form(None), client_assertion_type: str = Form(None), client_assertion: str = Form(None), - assertion: str = Form(None)): + assertion: str = Form(None), + redirect_uri: str = Form(None)): request.body = { 'grant_type': grant_type, 'scope': scope, @@ -99,6 +132,9 @@ def issue_token( if assertion: request.body.update({'assertion': assertion}) + if redirect_uri: + request.body.update({'redirect_uri': redirect_uri}) + return server.create_token_response(request=request) return server diff --git a/tests/fastapi/test_oauth2/test_openid_code_grant.py b/tests/fastapi/test_oauth2/test_openid_code_grant.py new file mode 100644 index 00000000..22bcb2c6 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_openid_code_grant.py @@ -0,0 +1,275 @@ +import json +from authlib.common.encoding import to_unicode +from authlib.common.urls import urlparse, url_decode, url_encode +from authlib.jose import JsonWebToken, JsonWebKey +from authlib.oidc.core import CodeIDToken +from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from tests.util import get_file_path +from .database import db +from .models import User, Client, exists_nonce +from .models import CodeGrantMixin, save_authorization_code +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + +DUMMY_JWT_CONFIG = { + 'key': 'secret', + 'alg': 'HS256', + 'iss': 'Authlib', + 'exp': 3600, +} + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class OpenIDCode(_OpenIDCode): + def get_jwt_config(self, grant): + return DUMMY_JWT_CONFIG + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + +class BaseTestCase(TestCase): + def config_app(self): + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': 'secret', + 'alg': 'HS256', + }) + + def prepare_data(self): + self.config_app() + server = create_authorization_server(self.app) + server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) + + user = User(username='foo') + db.add(user) + db.commit() + + client = Client( + user_id=user.id, + client_id='code-client', + client_secret='code-secret', + ) + client.set_client_metadata({ + 'redirect_uris': ['https://a.b'], + 'scope': 'openid profile address', + 'response_types': ['code'], + 'grant_types': ['authorization_code'], + }) + db.add(client) + db.commit() + + +class OpenIDCodeTest(BaseTestCase): + def test_authorize_token(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'code', + 'client_id': 'code-client', + 'state': 'bar', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1' + }) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('id_token', resp) + + jwt = JsonWebToken() + claims = jwt.decode( + resp['id_token'], 'secret', + claims_cls=CodeIDToken, + claims_options={'iss': {'value': 'Authlib'}} + ) + claims.validate() + + def test_pure_code_flow(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'code', + 'client_id': 'code-client', + 'state': 'bar', + 'scope': 'profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1' + }) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertNotIn('id_token', resp) + + def test_nonce_replay(self): + self.prepare_data() + data = { + 'response_type': 'code', + 'client_id': 'code-client', + 'user_id': '1', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b' + } + rv = self.client.post('/oauth/authorize', data=data) + self.assertIn('code=', rv.headers['location']) + + rv = self.client.post('/oauth/authorize', data=data) + self.assertIn('error=', rv.headers['location']) + + def test_prompt(self): + self.prepare_data() + params = [ + ('response_type', 'code'), + ('client_id', 'code-client'), + ('state', 'bar'), + ('nonce', 'abc'), + ('scope', 'openid profile'), + ('redirect_uri', 'https://a.b') + ] + query = url_encode(params) + rv = self.client.get('/oauth/authorize?' + query) + self.assertEqual(rv.json(), 'login') + + query = url_encode(params + [('user_id', '1')]) + rv = self.client.get('/oauth/authorize?' + query) + self.assertEqual(rv.json(), 'ok') + + query = url_encode(params + [('prompt', 'login')]) + rv = self.client.get('/oauth/authorize?' + query) + self.assertEqual(rv.json(), 'login') + + +class RSAOpenIDCodeTest(BaseTestCase): + def config_app(self): + jwt_key_path = get_file_path('jwk_private.json') + with open(jwt_key_path, 'r') as f: + jwt_key = json.load(f) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'RS256', + }) + + def get_validate_key(self): + with open(get_file_path('jwk_public.json'), 'r') as f: + return json.load(f) + + def test_authorize_token(self): + # generate refresh token + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'code', + 'client_id': 'code-client', + 'state': 'bar', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1' + }) + self.assertIn('code=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('code-client', 'code-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('id_token', resp) + + jwt = JsonWebToken() + claims = jwt.decode( + resp['id_token'], + self.get_validate_key(), + claims_cls=CodeIDToken, + claims_options={'iss': {'value': 'Authlib'}} + ) + claims.validate() + + +class JWKSOpenIDCodeTest(RSAOpenIDCodeTest): + def config_app(self): + jwt_key_path = get_file_path('jwks_private.json') + with open(jwt_key_path, 'r') as f: + jwt_key = json.load(f) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'PS256', + }) + + def get_validate_key(self): + with open(get_file_path('jwks_public.json'), 'r') as f: + return JsonWebKey.import_key_set(json.load(f)) + + +class ECOpenIDCodeTest(RSAOpenIDCodeTest): + def config_app(self): + jwt_key_path = get_file_path('secp521r1-private.json') + with open(jwt_key_path, 'r') as f: + jwt_key = json.load(f) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'ES512', + }) + + def get_validate_key(self): + with open(get_file_path('secp521r1-public.json'), 'r') as f: + return json.load(f) + + +class PEMOpenIDCodeTest(RSAOpenIDCodeTest): + def config_app(self): + jwt_key_path = get_file_path('rsa_private.pem') + with open(jwt_key_path, 'r') as f: + jwt_key = to_unicode(f.read()) + + DUMMY_JWT_CONFIG.update({ + 'iss': 'Authlib', + 'key': jwt_key, + 'alg': 'RS256', + }) + + def get_validate_key(self): + with open(get_file_path('rsa_public.pem'), 'r') as f: + return f.read() From fc34de879d6872928b5ff25aaec488cb5a0aa790 Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Wed, 21 Oct 2020 15:46:31 -0300 Subject: [PATCH 09/16] Implemented the openid hybrid grant pytests for FastAPI integration --- tests/fastapi/test_oauth2/oauth2_server.py | 4 + .../test_oauth2/test_openid_hybrid_grant.py | 286 ++++++++++++++++++ 2 files changed, 290 insertions(+) create mode 100644 tests/fastapi/test_oauth2/test_openid_hybrid_grant.py diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py index a4b9a5b2..70affe02 100644 --- a/tests/fastapi/test_oauth2/oauth2_server.py +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -54,6 +54,7 @@ def authorize_post(request: Request, scope: str = Form(None), nonce: str = Form(None), redirect_uri: str = Form(None), + response_mode: str = Form(None), user_id: str = Form(None)): if not user_id: user_id = request.query_params.get('user_id') @@ -80,6 +81,9 @@ def authorize_post(request: Request, if redirect_uri: request.body.update({'redirect_uri': redirect_uri}) + if response_mode: + request.body.update({'response_mode': response_mode}) + if user_id: grant_user = db.query(User).filter(User.id == int(user_id)).first() else: diff --git a/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py new file mode 100644 index 00000000..4f55f1e1 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py @@ -0,0 +1,286 @@ +from authlib.common.urls import urlparse, url_decode +from authlib.jose import JWT +from authlib.oidc.core import HybridIDToken +from authlib.oidc.core.grants import ( + OpenIDCode as _OpenIDCode, + OpenIDHybridGrant as _OpenIDHybridGrant, +) +from authlib.oauth2.rfc6749.grants import ( + AuthorizationCodeGrant as _AuthorizationCodeGrant, +) +from .database import db +from .models import User, Client, exists_nonce +from .models import CodeGrantMixin, save_authorization_code +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + +JWT_CONFIG = {'iss': 'Authlib', 'key': 'secret', 'alg': 'HS256', 'exp': 3600} + + +class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + +class OpenIDCode(_OpenIDCode): + def get_jwt_config(self, grant): + return dict(JWT_CONFIG) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + +class OpenIDHybridGrant(_OpenIDHybridGrant): + def save_authorization_code(self, code, request): + return save_authorization_code(code, request) + + def get_jwt_config(self): + return dict(JWT_CONFIG) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + +class OpenIDCodeTest(TestCase): + def prepare_data(self): + server = create_authorization_server(self.app) + server.register_grant(OpenIDHybridGrant) + server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) + + user = User(username='foo') + db.add(user) + db.commit() + + client = Client( + user_id=user.id, + client_id='hybrid-client', + client_secret='hybrid-secret', + ) + client.set_client_metadata({ + 'redirect_uris': ['https://a.b'], + 'scope': 'openid profile address', + 'response_types': ['code id_token', 'code token', 'code id_token token'], + 'grant_types': ['authorization_code'], + }) + db.add(client) + db.commit() + + def validate_claims(self, id_token, params): + jwt = JWT() + claims = jwt.decode( + id_token, 'secret', + claims_cls=HybridIDToken, + claims_params=params + ) + claims.validate() + + def test_invalid_client_id(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'code token', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'invalid-client', + 'response_type': 'code token', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_require_nonce(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code token', + 'scope': 'openid profile', + 'state': 'bar', + 'redirect_uri': 'https://a.b', + 'user_id': '1' + }) + self.assertIn('error=invalid_request', rv.headers['location']) + self.assertIn('nonce', rv.headers['location']) + + def test_invalid_response_type(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code id_token invalid', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_grant') + + def test_invalid_scope(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code id_token', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + self.assertIn('error=invalid_scope', rv.headers['location']) + + def test_access_denied(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code token', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + }) + self.assertIn('error=access_denied', rv.headers['location']) + + def test_code_access_token(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code token', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + self.assertIn('code=', rv.headers['location']) + self.assertIn('access_token=', rv.headers['location']) + self.assertNotIn('id_token=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) + self.assertEqual(params['state'], 'bar') + + code = params['code'] + headers = self.create_basic_header('hybrid-client', 'hybrid-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('id_token', resp) + + def test_code_id_token(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code id_token', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + self.assertIn('code=', rv.headers['location']) + self.assertIn('id_token=', rv.headers['location']) + self.assertNotIn('access_token=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) + self.assertEqual(params['state'], 'bar') + + params['nonce'] = 'abc' + params['client_id'] = 'hybrid-client' + self.validate_claims(params['id_token'], params) + + code = params['code'] + headers = self.create_basic_header('hybrid-client', 'hybrid-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('id_token', resp) + + def test_code_id_token_access_token(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code id_token token', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + self.assertIn('code=', rv.headers['location']) + self.assertIn('id_token=', rv.headers['location']) + self.assertIn('access_token=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) + self.assertEqual(params['state'], 'bar') + self.validate_claims(params['id_token'], params) + + code = params['code'] + headers = self.create_basic_header('hybrid-client', 'hybrid-secret') + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'authorization_code', + 'redirect_uri': 'https://a.b', + 'code': code, + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('id_token', resp) + + def test_response_mode_query(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code id_token token', + 'response_mode': 'query', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + self.assertIn('code=', rv.headers['location']) + self.assertIn('id_token=', rv.headers['location']) + self.assertIn('access_token=', rv.headers['location']) + + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.assertEqual(params['state'], 'bar') + + def test_response_mode_form_post(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'client_id': 'hybrid-client', + 'response_type': 'code id_token token', + 'response_mode': 'form_post', + 'state': 'bar', + 'nonce': 'abc', + 'scope': 'openid profile', + 'redirect_uri': 'https://a.b', + 'user_id': '1', + }) + resp = rv.json() + self.assertIn('name="code"', resp) + self.assertIn('name="id_token"', resp) + self.assertIn('name="access_token"', resp) From 786259eddc341593c2c7bc2807d153fd80cd3820 Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Wed, 21 Oct 2020 15:50:17 -0300 Subject: [PATCH 10/16] Implemented the openid implict grant pytests for FastAPI integration --- .../test_oauth2/test_openid_implict_grant.py | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) create mode 100644 tests/fastapi/test_oauth2/test_openid_implict_grant.py diff --git a/tests/fastapi/test_oauth2/test_openid_implict_grant.py b/tests/fastapi/test_oauth2/test_openid_implict_grant.py new file mode 100644 index 00000000..ea9be8fa --- /dev/null +++ b/tests/fastapi/test_oauth2/test_openid_implict_grant.py @@ -0,0 +1,173 @@ +from authlib.jose import JWT +from authlib.oidc.core import ImplicitIDToken +from authlib.oidc.core.grants import ( + OpenIDImplicitGrant as _OpenIDImplicitGrant +) +from authlib.common.urls import urlparse, url_decode, add_params_to_uri +from .database import db +from .models import User, Client, exists_nonce +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class OpenIDImplicitGrant(_OpenIDImplicitGrant): + def get_jwt_config(self): + return dict(key='secret', alg='HS256', iss='Authlib', exp=3600) + + def generate_user_info(self, user, scopes): + return user.generate_user_info(scopes) + + def exists_nonce(self, nonce, request): + return exists_nonce(nonce, request) + + +class ImplicitTest(TestCase): + def prepare_data(self): + server = create_authorization_server(self.app) + server.register_grant(OpenIDImplicitGrant) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='implicit-client', + client_secret='', + ) + client.set_client_metadata({ + 'redirect_uris': ['https://a.b/c'], + 'scope': 'openid profile', + 'token_endpoint_auth_method': 'none', + 'response_types': ['id_token', 'id_token token'], + }) + self.authorize_url = ( + '/oauth/authorize?response_type=token' + '&client_id=implicit-client' + ) + db.add(client) + db.commit() + + def validate_claims(self, id_token, params): + jwt = JWT(['HS256']) + claims = jwt.decode( + id_token, 'secret', + claims_cls=ImplicitIDToken, + claims_params=params + ) + claims.validate() + + def test_consent_view(self): + self.prepare_data() + rv = self.client.get(add_params_to_uri('/oauth/authorize', { + 'response_type': 'id_token', + 'client_id': 'implicit-client', + 'scope': 'openid profile', + 'state': 'foo', + 'redirect_uri': 'https://a.b/c', + 'user_id': '1' + })) + self.assertIn('error=invalid_request', rv.json()) + self.assertIn('nonce', rv.json()) + + def test_require_nonce(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'id_token', + 'client_id': 'implicit-client', + 'scope': 'openid profile', + 'state': 'bar', + 'redirect_uri': 'https://a.b/c', + 'user_id': '1' + }) + self.assertIn('error=invalid_request', rv.headers['location']) + self.assertIn('nonce', rv.headers['location']) + + def test_missing_openid_in_scope(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'id_token token', + 'client_id': 'implicit-client', + 'scope': 'profile', + 'state': 'bar', + 'nonce': 'abc', + 'redirect_uri': 'https://a.b/c', + 'user_id': '1' + }) + self.assertIn('error=invalid_scope', rv.headers['location']) + + def test_denied(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'id_token', + 'client_id': 'implicit-client', + 'scope': 'openid profile', + 'state': 'bar', + 'nonce': 'abc', + 'redirect_uri': 'https://a.b/c', + }) + self.assertIn('error=access_denied', rv.headers['location']) + + def test_authorize_access_token(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'id_token token', + 'client_id': 'implicit-client', + 'scope': 'openid profile', + 'state': 'bar', + 'nonce': 'abc', + 'redirect_uri': 'https://a.b/c', + 'user_id': '1' + }) + self.assertIn('access_token=', rv.headers['location']) + self.assertIn('id_token=', rv.headers['location']) + self.assertIn('state=bar', rv.headers['location']) + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) + self.validate_claims(params['id_token'], params) + + def test_authorize_id_token(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'id_token', + 'client_id': 'implicit-client', + 'scope': 'openid profile', + 'state': 'bar', + 'nonce': 'abc', + 'redirect_uri': 'https://a.b/c', + 'user_id': '1' + }) + self.assertIn('id_token=', rv.headers['location']) + self.assertIn('state=bar', rv.headers['location']) + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) + self.validate_claims(params['id_token'], params) + + def test_response_mode_query(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'id_token', + 'response_mode': 'query', + 'client_id': 'implicit-client', + 'scope': 'openid profile', + 'state': 'bar', + 'nonce': 'abc', + 'redirect_uri': 'https://a.b/c', + 'user_id': '1' + }) + self.assertIn('id_token=', rv.headers['location']) + self.assertIn('state=bar', rv.headers['location']) + params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) + self.validate_claims(params['id_token'], params) + + def test_response_mode_form_post(self): + self.prepare_data() + rv = self.client.post('/oauth/authorize', data={ + 'response_type': 'id_token', + 'response_mode': 'form_post', + 'client_id': 'implicit-client', + 'scope': 'openid profile', + 'state': 'bar', + 'nonce': 'abc', + 'redirect_uri': 'https://a.b/c', + 'user_id': '1' + }) + self.assertIn('name="id_token"', rv.json()) + self.assertIn('name="state"', rv.json()) From 3414d825c7aa595ed8abadd4c56b7422917292cf Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Wed, 21 Oct 2020 16:08:30 -0300 Subject: [PATCH 11/16] Implemented the password grant pytests for FastAPI integration --- tests/fastapi/test_oauth2/oauth2_server.py | 17 +- .../test_oauth2/test_password_grant.py | 166 ++++++++++++++++++ 2 files changed, 181 insertions(+), 2 deletions(-) create mode 100644 tests/fastapi/test_oauth2/test_password_grant.py diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py index 70affe02..fcffb595 100644 --- a/tests/fastapi/test_oauth2/oauth2_server.py +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -91,10 +91,10 @@ def authorize_post(request: Request, return server.create_authorization_response(request=request, grant_user=grant_user) - @app.post('/oauth/token') + @app.api_route('/oauth/token', methods=["GET", "POST"]) def issue_token( request: Request, - grant_type: str = Form(...), + grant_type: str = Form(None), scope: str = Form(None), code: str = Form(None), refresh_token: str = Form(None), @@ -105,11 +105,18 @@ def issue_token( client_assertion_type: str = Form(None), client_assertion: str = Form(None), assertion: str = Form(None), + username: str = Form(None), + password: str = Form(None), redirect_uri: str = Form(None)): request.body = { 'grant_type': grant_type, 'scope': scope, } + + if not grant_type: + grant_type = request.query_params.get('grant_type') + request.body.update({'grant_type': grant_type}) + if grant_type == 'authorization_code': request.body.update({'code': code}) elif grant_type == 'refresh_token': @@ -139,6 +146,12 @@ def issue_token( if redirect_uri: request.body.update({'redirect_uri': redirect_uri}) + if username: + request.body.update({'username': username}) + + if password: + request.body.update({'password': password}) + return server.create_token_response(request=request) return server diff --git a/tests/fastapi/test_oauth2/test_password_grant.py b/tests/fastapi/test_oauth2/test_password_grant.py new file mode 100644 index 00000000..b230afbe --- /dev/null +++ b/tests/fastapi/test_oauth2/test_password_grant.py @@ -0,0 +1,166 @@ +from authlib.common.urls import add_params_to_uri +from authlib.oauth2.rfc6749.grants import ( + ResourceOwnerPasswordCredentialsGrant as _PasswordGrant, +) +from .database import db +from .models import User, Client +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class PasswordGrant(_PasswordGrant): + def authenticate_user(self, username, password): + user = db.query(User).filter(User.username == username).first() + if user.check_password(password): + return user + + +class PasswordTest(TestCase): + def prepare_data(self, grant_type='password'): + server = create_authorization_server(self.app) + server.register_grant(PasswordGrant) + self.server = server + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='password-client', + client_secret='password-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'grant_types': [grant_type], + 'redirect_uris': ['http://localhost/authorized'], + }) + db.add(client) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'password-client', 'invalid-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_scope(self): + self.prepare_data() + self.server.metadata = {'scopes_supported': ['profile']} + headers = self.create_basic_header( + 'password-client', 'password-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + 'scope': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_scope') + + def test_invalid_request(self): + self.prepare_data() + headers = self.create_basic_header( + 'password-client', 'password-secret' + ) + + rv = self.client.get(add_params_to_uri('/oauth/token', { + 'grant_type': 'password', + }), headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unsupported_grant_type') + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'wrong', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_invalid_grant_type(self): + self.prepare_data(grant_type='invalid') + headers = self.create_basic_header( + 'password-client', 'password-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_authorize_token(self): + self.prepare_data() + headers = self.create_basic_header( + 'password-client', 'password-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + self.prepare_data() + headers = self.create_basic_header( + 'password-client', 'password-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('p-password.1.', resp['access_token']) + + def test_custom_expires_in(self): + self.app.config.update({ + 'OAUTH2_TOKEN_EXPIRES_IN': {'password': 1800} + }) + self.prepare_data() + headers = self.create_basic_header( + 'password-client', 'password-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'password', + 'username': 'foo', + 'password': 'ok', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertEqual(resp['expires_in'], 1800) From a84ca8360bc38390409716edcf40816ac1f935b1 Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Wed, 21 Oct 2020 16:11:22 -0300 Subject: [PATCH 12/16] Implemented the refresh token pytests for FastAPI integration --- .../fastapi/test_oauth2/test_refresh_token.py | 228 ++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 tests/fastapi/test_oauth2/test_refresh_token.py diff --git a/tests/fastapi/test_oauth2/test_refresh_token.py b/tests/fastapi/test_oauth2/test_refresh_token.py new file mode 100644 index 00000000..df8c3541 --- /dev/null +++ b/tests/fastapi/test_oauth2/test_refresh_token.py @@ -0,0 +1,228 @@ +from authlib.oauth2.rfc6749.grants import ( + RefreshTokenGrant as _RefreshTokenGrant, +) +from .database import db +from .models import User, Client, Token +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +class RefreshTokenGrant(_RefreshTokenGrant): + def authenticate_refresh_token(self, refresh_token): + item = db.query(Token).filter(Token.refresh_token == refresh_token).first() + if item and not item.revoked and not item.is_refresh_token_expired(): + return item + + def authenticate_user(self, credential): + return db.query(User).filter(User.id == int(credential.user_id)).first() + + def revoke_old_credential(self, credential): + credential.revoked = True + db.add(credential) + db.commit() + + +class RefreshTokenTest(TestCase): + def prepare_data(self, grant_type='refresh_token'): + server = create_authorization_server(self.app) + server.register_grant(RefreshTokenGrant) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='refresh-client', + client_secret='refresh-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'grant_types': [grant_type], + 'redirect_uris': ['http://localhost/authorized'], + }) + db.add(client) + db.commit() + + def create_token(self, scope='profile', user_id=1): + token = Token( + user_id=user_id, + client_id='refresh-client', + token_type='bearer', + access_token='a1', + refresh_token='r1', + scope=scope, + expires_in=3600, + ) + db.add(token) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'invalid-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'refresh-client', 'invalid-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_refresh_token(self): + self.prepare_data() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + self.assertIn('Missing', resp['error_description']) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'foo', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_grant') + + def test_invalid_scope(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_scope') + + def test_invalid_scope_none(self): + self.prepare_data() + self.create_token(scope=None) + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'invalid', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_scope') + + def test_invalid_user(self): + self.prepare_data() + self.create_token(user_id=5) + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + def test_invalid_grant_type(self): + self.prepare_data(grant_type='invalid') + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unauthorized_client') + + def test_authorize_token_no_scope(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_authorize_token_scope(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + def test_revoke_old_credential(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + 'scope': 'profile', + }, headers=headers) + self.assertEqual(rv.status_code, 400) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_grant') + + def test_token_generator(self): + m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' + self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'refresh-client', 'refresh-secret' + ) + rv = self.client.post('/oauth/token', data={ + 'grant_type': 'refresh_token', + 'refresh_token': 'r1', + }, headers=headers) + resp = rv.json() + self.assertIn('access_token', resp) + self.assertIn('r-refresh_token.1.', resp['access_token']) From ad116693b7f68bb2dfd877cc1a33146482a9a464 Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Wed, 21 Oct 2020 16:16:45 -0300 Subject: [PATCH 13/16] Implemented the token revocation pytests for FastAPI integration --- .../test_oauth2/test_revocation_endpoint.py | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 tests/fastapi/test_oauth2/test_revocation_endpoint.py diff --git a/tests/fastapi/test_oauth2/test_revocation_endpoint.py b/tests/fastapi/test_oauth2/test_revocation_endpoint.py new file mode 100644 index 00000000..df8bed6d --- /dev/null +++ b/tests/fastapi/test_oauth2/test_revocation_endpoint.py @@ -0,0 +1,130 @@ +from fastapi import Request, Form +from authlib.integrations.sqla_oauth2 import create_revocation_endpoint +from .database import db +from .models import User, Client, Token +from .oauth2_server import TestCase +from .oauth2_server import create_authorization_server + + +RevocationEndpoint = create_revocation_endpoint(db, Token) + + +class RevokeTokenTest(TestCase): + def prepare_data(self): + app = self.app + server = create_authorization_server(app) + server.register_endpoint(RevocationEndpoint) + + @app.post('/oauth/revoke') + def revoke_token(request: Request, + token: str = Form(None), + token_type_hint: str = Form(None)): + request.body = {} + if token: + request.body.update({'token': token}) + if token_type_hint: + request.body.update({'token_type_hint': token_type_hint}) + return server.create_endpoint_response('revocation', request=request) + + user = User(username='foo') + db.add(user) + db.commit() + client = Client( + user_id=user.id, + client_id='revoke-client', + client_secret='revoke-secret', + ) + client.set_client_metadata({ + 'scope': 'profile', + 'redirect_uris': ['http://localhost/authorized'], + }) + db.add(client) + db.commit() + + def create_token(self): + token = Token( + user_id=1, + client_id='revoke-client', + token_type='bearer', + access_token='a1', + refresh_token='r1', + scope='profile', + expires_in=3600, + ) + db.add(token) + db.commit() + + def test_invalid_client(self): + self.prepare_data() + rv = self.client.post('/oauth/revoke') + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = {'Authorization': 'invalid token_string'} + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'invalid-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + headers = self.create_basic_header( + 'revoke-client', 'invalid-secret' + ) + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_client') + + def test_invalid_token(self): + self.prepare_data() + headers = self.create_basic_header( + 'revoke-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'invalid_request') + + rv = self.client.post('/oauth/revoke', data={ + 'token': 'invalid-token', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + 'token_type_hint': 'unsupported_token_type', + }, headers=headers) + resp = rv.json() + self.assertEqual(resp['error'], 'unsupported_token_type') + + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + 'token_type_hint': 'refresh_token', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + + def test_revoke_token_with_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'revoke-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + 'token_type_hint': 'access_token', + }, headers=headers) + self.assertEqual(rv.status_code, 200) + + def test_revoke_token_without_hint(self): + self.prepare_data() + self.create_token() + headers = self.create_basic_header( + 'revoke-client', 'revoke-secret' + ) + rv = self.client.post('/oauth/revoke', data={ + 'token': 'a1', + }, headers=headers) + self.assertEqual(rv.status_code, 200) From 1fdb53f2330f4f2e304cc9924d3a06f465a4233a Mon Sep 17 00:00:00 2001 From: Gabriel Machado Santos Date: Thu, 22 Oct 2020 15:42:42 -0300 Subject: [PATCH 14/16] Merged the database into models file --- tests/fastapi/test_oauth2/database.py | 15 --------------- tests/fastapi/test_oauth2/models.py | 16 +++++++++++++--- tests/fastapi/test_oauth2/oauth2_server.py | 3 +-- .../test_oauth2/test_authorization_code_grant.py | 3 +-- .../test_oauth2/test_client_credentials_grant.py | 3 +-- .../test_client_registration_endpoint.py | 3 +-- tests/fastapi/test_oauth2/test_code_challenge.py | 3 +-- .../test_oauth2/test_device_code_grant.py | 3 +-- tests/fastapi/test_oauth2/test_implicit_grant.py | 3 +-- .../test_oauth2/test_introspection_endpoint.py | 3 +-- .../test_oauth2/test_jwt_bearer_client_auth.py | 3 +-- .../fastapi/test_oauth2/test_jwt_bearer_grant.py | 3 +-- tests/fastapi/test_oauth2/test_oauth2_server.py | 3 +-- .../test_oauth2/test_openid_code_grant.py | 3 +-- .../test_oauth2/test_openid_hybrid_grant.py | 3 +-- .../test_oauth2/test_openid_implict_grant.py | 3 +-- tests/fastapi/test_oauth2/test_password_grant.py | 3 +-- tests/fastapi/test_oauth2/test_refresh_token.py | 3 +-- .../test_oauth2/test_revocation_endpoint.py | 3 +-- 19 files changed, 30 insertions(+), 52 deletions(-) delete mode 100644 tests/fastapi/test_oauth2/database.py diff --git a/tests/fastapi/test_oauth2/database.py b/tests/fastapi/test_oauth2/database.py deleted file mode 100644 index 8cff1ce2..00000000 --- a/tests/fastapi/test_oauth2/database.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker - -SQLALCHEMY_DATABASE_URL = 'sqlite:///fastapi_auth2_sql.db' - -engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False} -) - -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -Base = declarative_base() - -db = SessionLocal() diff --git a/tests/fastapi/test_oauth2/models.py b/tests/fastapi/test_oauth2/models.py index 0fdedd54..12937077 100644 --- a/tests/fastapi/test_oauth2/models.py +++ b/tests/fastapi/test_oauth2/models.py @@ -1,13 +1,23 @@ import time -from sqlalchemy import Column, ForeignKey, Integer, String -from sqlalchemy.orm import relationship +from sqlalchemy import Column, ForeignKey, Integer, String, create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship, sessionmaker from authlib.integrations.sqla_oauth2 import ( OAuth2ClientMixin, OAuth2TokenMixin, OAuth2AuthorizationCodeMixin, ) from authlib.oidc.core import UserInfo -from .database import Base, db + +engine = create_engine( + 'sqlite:///fastapi_auth2_sql.db', connect_args={'check_same_thread': False} +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + +db = SessionLocal() class User(Base): diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py index fcffb595..3c03280e 100644 --- a/tests/fastapi/test_oauth2/oauth2_server.py +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -12,8 +12,7 @@ ) from authlib.integrations.fastapi_oauth2 import AuthorizationServer from authlib.oauth2 import OAuth2Error -from .models import User, Client, Token -from .database import Base, engine, db +from .models import Base, engine, db, User, Client, Token os.environ['AUTHLIB_INSECURE_TRANSPORT'] = 'true' diff --git a/tests/fastapi/test_oauth2/test_authorization_code_grant.py b/tests/fastapi/test_oauth2/test_authorization_code_grant.py index edb0be14..907b0da1 100644 --- a/tests/fastapi/test_oauth2/test_authorization_code_grant.py +++ b/tests/fastapi/test_oauth2/test_authorization_code_grant.py @@ -2,8 +2,7 @@ from authlib.oauth2.rfc6749.grants import ( AuthorizationCodeGrant as _AuthorizationCodeGrant, ) -from .database import db -from .models import User, Client, AuthorizationCode +from .models import db, User, Client, AuthorizationCode from .models import CodeGrantMixin, save_authorization_code from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_client_credentials_grant.py b/tests/fastapi/test_oauth2/test_client_credentials_grant.py index 77d56551..19f48fb1 100644 --- a/tests/fastapi/test_oauth2/test_client_credentials_grant.py +++ b/tests/fastapi/test_oauth2/test_client_credentials_grant.py @@ -1,6 +1,5 @@ from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant -from .database import db -from .models import User, Client +from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_client_registration_endpoint.py b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py index 99bf9add..d8bc7ddb 100644 --- a/tests/fastapi/test_oauth2/test_client_registration_endpoint.py +++ b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py @@ -3,8 +3,7 @@ from authlib.jose import jwt from authlib.oauth2.rfc7591 import ClientRegistrationEndpoint as _ClientRegistrationEndpoint from tests.util import read_file_path -from .database import db -from .models import User, Client +from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_code_challenge.py b/tests/fastapi/test_oauth2/test_code_challenge.py index 90df8186..c9e8188c 100644 --- a/tests/fastapi/test_oauth2/test_code_challenge.py +++ b/tests/fastapi/test_oauth2/test_code_challenge.py @@ -5,8 +5,7 @@ CodeChallenge as _CodeChallenge, create_s256_code_challenge, ) -from .database import db -from .models import User, Client +from .models import db, User, Client from .models import CodeGrantMixin, save_authorization_code from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_device_code_grant.py b/tests/fastapi/test_oauth2/test_device_code_grant.py index bc76a7ff..bd5518b1 100644 --- a/tests/fastapi/test_oauth2/test_device_code_grant.py +++ b/tests/fastapi/test_oauth2/test_device_code_grant.py @@ -5,8 +5,7 @@ DeviceCodeGrant as _DeviceCodeGrant, DeviceCredentialDict, ) -from .database import db -from .models import User, Client +from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_implicit_grant.py b/tests/fastapi/test_oauth2/test_implicit_grant.py index a08b0a7f..4653edb1 100644 --- a/tests/fastapi/test_oauth2/test_implicit_grant.py +++ b/tests/fastapi/test_oauth2/test_implicit_grant.py @@ -1,6 +1,5 @@ from authlib.oauth2.rfc6749.grants import ImplicitGrant -from .database import db -from .models import User, Client +from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_introspection_endpoint.py b/tests/fastapi/test_oauth2/test_introspection_endpoint.py index ee51d0e0..0e264fe2 100644 --- a/tests/fastapi/test_oauth2/test_introspection_endpoint.py +++ b/tests/fastapi/test_oauth2/test_introspection_endpoint.py @@ -1,8 +1,7 @@ from fastapi import Request, Form from authlib.integrations.sqla_oauth2 import create_query_token_func from authlib.oauth2.rfc7662 import IntrospectionEndpoint -from .database import db -from .models import User, Client, Token +from .models import db, User, Client, Token from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py b/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py index 724924ff..f038374f 100644 --- a/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py +++ b/tests/fastapi/test_oauth2/test_jwt_bearer_client_auth.py @@ -5,8 +5,7 @@ private_key_jwt_sign, ) from tests.util import read_file_path -from .database import db -from .models import User, Client +from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py b/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py index c8587829..a3d80abf 100644 --- a/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py +++ b/tests/fastapi/test_oauth2/test_jwt_bearer_grant.py @@ -1,6 +1,5 @@ from authlib.oauth2.rfc7523 import JWTBearerGrant as _JWTBearerGrant -from .database import db -from .models import User, Client +from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_oauth2_server.py b/tests/fastapi/test_oauth2/test_oauth2_server.py index 94407188..99fd4b9d 100644 --- a/tests/fastapi/test_oauth2/test_oauth2_server.py +++ b/tests/fastapi/test_oauth2/test_oauth2_server.py @@ -1,8 +1,7 @@ from fastapi import Request from authlib.integrations.fastapi_oauth2 import ResourceProtector from authlib.integrations.sqla_oauth2 import create_bearer_token_validator -from .database import db -from .models import User, Client, Token +from .models import db, User, Client, Token from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_openid_code_grant.py b/tests/fastapi/test_oauth2/test_openid_code_grant.py index 22bcb2c6..ed230299 100644 --- a/tests/fastapi/test_oauth2/test_openid_code_grant.py +++ b/tests/fastapi/test_oauth2/test_openid_code_grant.py @@ -8,8 +8,7 @@ AuthorizationCodeGrant as _AuthorizationCodeGrant, ) from tests.util import get_file_path -from .database import db -from .models import User, Client, exists_nonce +from .models import db, User, Client, exists_nonce from .models import CodeGrantMixin, save_authorization_code from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py index 4f55f1e1..e6bdbb60 100644 --- a/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py +++ b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py @@ -8,8 +8,7 @@ from authlib.oauth2.rfc6749.grants import ( AuthorizationCodeGrant as _AuthorizationCodeGrant, ) -from .database import db -from .models import User, Client, exists_nonce +from .models import db, User, Client, exists_nonce from .models import CodeGrantMixin, save_authorization_code from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_openid_implict_grant.py b/tests/fastapi/test_oauth2/test_openid_implict_grant.py index ea9be8fa..b6e0a953 100644 --- a/tests/fastapi/test_oauth2/test_openid_implict_grant.py +++ b/tests/fastapi/test_oauth2/test_openid_implict_grant.py @@ -4,8 +4,7 @@ OpenIDImplicitGrant as _OpenIDImplicitGrant ) from authlib.common.urls import urlparse, url_decode, add_params_to_uri -from .database import db -from .models import User, Client, exists_nonce +from .models import db, User, Client, exists_nonce from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_password_grant.py b/tests/fastapi/test_oauth2/test_password_grant.py index b230afbe..a7fd50f9 100644 --- a/tests/fastapi/test_oauth2/test_password_grant.py +++ b/tests/fastapi/test_oauth2/test_password_grant.py @@ -2,8 +2,7 @@ from authlib.oauth2.rfc6749.grants import ( ResourceOwnerPasswordCredentialsGrant as _PasswordGrant, ) -from .database import db -from .models import User, Client +from .models import db, User, Client from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_refresh_token.py b/tests/fastapi/test_oauth2/test_refresh_token.py index df8c3541..a6564aa9 100644 --- a/tests/fastapi/test_oauth2/test_refresh_token.py +++ b/tests/fastapi/test_oauth2/test_refresh_token.py @@ -1,8 +1,7 @@ from authlib.oauth2.rfc6749.grants import ( RefreshTokenGrant as _RefreshTokenGrant, ) -from .database import db -from .models import User, Client, Token +from .models import db, User, Client, Token from .oauth2_server import TestCase from .oauth2_server import create_authorization_server diff --git a/tests/fastapi/test_oauth2/test_revocation_endpoint.py b/tests/fastapi/test_oauth2/test_revocation_endpoint.py index df8bed6d..2c2bc14a 100644 --- a/tests/fastapi/test_oauth2/test_revocation_endpoint.py +++ b/tests/fastapi/test_oauth2/test_revocation_endpoint.py @@ -1,7 +1,6 @@ from fastapi import Request, Form from authlib.integrations.sqla_oauth2 import create_revocation_endpoint -from .database import db -from .models import User, Client, Token +from .models import db, User, Client, Token from .oauth2_server import TestCase from .oauth2_server import create_authorization_server From cf1d6d365b68fa885c1be2b4c81578be5309aea8 Mon Sep 17 00:00:00 2001 From: "alexey.turlapov" Date: Tue, 15 Dec 2020 15:17:25 -0600 Subject: [PATCH 15/16] Adding fastapi tests run to `make tox` and to github pipeline --- .github/workflows/python.yml | 2 +- Makefile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index dbbdfa88..a9256cf3 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -41,7 +41,7 @@ jobs: - name: Test with tox ${{ matrix.python.toxenv }} env: - TOXENV: py,flask,django,starlette + TOXENV: py,flask,django,fastapi,starlette run: tox - name: Report coverage diff --git a/Makefile b/Makefile index 617a66e2..810d2ee3 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ clean: clean-build clean-pyc clean-docs clean-tox tests: - @TOXENV=py,flask,django,coverage tox + @TOXENV=py,flask,django,fastapi,coverage tox clean-build: @rm -fr build/ From 8666fe4df484c637df2747b2dd6762b8c5027ffc Mon Sep 17 00:00:00 2001 From: "alexey.turlapov" Date: Wed, 16 Dec 2020 16:06:17 -0600 Subject: [PATCH 16/16] Integrating refactored authlib into fastapi integration --- .../fastapi_oauth2/authorization_server.py | 71 +-- .../fastapi_oauth2/resource_protector.py | 49 +- tests/fastapi/test_oauth2/models.py | 70 +-- tests/fastapi/test_oauth2/oauth2_server.py | 157 +++---- .../test_authorization_code_grant.py | 368 ++++++++------- .../test_client_credentials_grant.py | 119 ++--- .../test_client_registration_endpoint.py | 184 ++++---- .../test_oauth2/test_device_code_grant.py | 294 ++++++------ .../test_oauth2/test_implicit_grant.py | 77 ++-- .../test_introspection_endpoint.py | 181 ++++---- .../fastapi/test_oauth2/test_oauth2_server.py | 161 ++++--- .../test_oauth2/test_openid_hybrid_grant.py | 429 ++++++++++-------- .../test_oauth2/test_openid_implict_grant.py | 249 +++++----- .../test_oauth2/test_password_grant.py | 240 +++++----- tox.ini | 1 - 15 files changed, 1428 insertions(+), 1222 deletions(-) diff --git a/authlib/integrations/fastapi_oauth2/authorization_server.py b/authlib/integrations/fastapi_oauth2/authorization_server.py index 5d550771..9e5feb13 100644 --- a/authlib/integrations/fastapi_oauth2/authorization_server.py +++ b/authlib/integrations/fastapi_oauth2/authorization_server.py @@ -1,56 +1,71 @@ """Implementation of authlib.oauth2.rfc6749.AuthorizationServer class for FastAPI.""" import json -from werkzeug.utils import import_string -from fastapi.responses import JSONResponse -from authlib.oauth2 import ( - OAuth2Request, - HttpRequest, - AuthorizationServer as _AuthorizationServer, -) + +from authlib.common.security import generate_token +from authlib.oauth2 import AuthorizationServer as _AuthorizationServer +from authlib.oauth2 import HttpRequest, OAuth2Request from authlib.oauth2.rfc6750 import BearerToken from authlib.oauth2.rfc8414 import AuthorizationServerMetadata -from authlib.common.security import generate_token +from fastapi.responses import JSONResponse +from werkzeug.utils import import_string class AuthorizationServer(_AuthorizationServer): """AuthorizationServer class.""" - def __init__(self, query_client=None, save_token=None): - super().__init__(query_client=query_client, save_token=save_token) + def __init__(self, app=None, query_client=None, save_token=None): + super(AuthorizationServer, self).__init__() + self._query_client = query_client + self._save_token = save_token self.config = {} + if app: + self.init_app(app) def init_app(self, app, query_client=None, save_token=None): """Initialize the FastAPI app.""" - if query_client is not None: + if query_client: self.query_client = query_client - if save_token is not None: + if save_token: self.save_token = save_token self.generate_token = create_bearer_token_generator(app.config) metadata_class = AuthorizationServerMetadata - metadata_file = app.config.get('OAUTH2_METADATA_FILE') + metadata_file = app.config.get("OAUTH2_METADATA_FILE") if metadata_file: with open(metadata_file) as metadata_file_content: metadata = metadata_class(json.loads(metadata_file_content)) metadata.validate() self.metadata = metadata - self.config.setdefault('error_uris', app.config.get('OAUTH2_ERROR_URIS')) + self.scopes_supported = app.config.get("OAUTH2_SCOPES_SUPPORTED") + self._error_uris = app.config.get("OAUTH2_ERROR_URIS") + + def query_client(self, client_id): + return self._query_client(client_id) + + def save_token(self, token, request): + return self._save_token(token, request) - def get_error_uris(self, request): - error_uris = self.config.get('error_uris') - if error_uris: - return dict(error_uris) - return None + def get_error_uri(self, request, error): + if self._error_uris: + uris = dict(self._error_uris) + return uris.get(error.error) def create_oauth2_request(self, request): - return OAuth2Request(request.method, str(request.url), request.body, request.headers) + return OAuth2Request( + request.method, str(request.url), request.body, request.headers + ) def create_json_request(self, request): - return HttpRequest(request.method, str(request.url), request.body, request.headers) + return HttpRequest( + request.method, str(request.url), request.body, request.headers + ) + + def send_signal(self, name, *args, **kwargs): + pass def handle_response(self, status, body, headers): return JSONResponse(content=body, status_code=status, headers=dict(headers)) @@ -63,7 +78,7 @@ def validate_consent_request(self, request=None, end_user=None): grant = self.get_authorization_grant(req) grant.validate_consent_request() - if not hasattr(grant, 'prompt'): + if not hasattr(grant, "prompt"): grant.prompt = None return grant @@ -75,18 +90,16 @@ def create_bearer_token_generator(config): generate ``refresh_token``, which can be turn on by configuration ``OAUTH2_REFRESH_TOKEN_GENERATOR=True``. """ - conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True) + conf = config.get("OAUTH2_ACCESS_TOKEN_GENERATOR", True) access_token_generator = create_token_generator(conf, 42) - conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False) + conf = config.get("OAUTH2_REFRESH_TOKEN_GENERATOR", False) refresh_token_generator = create_token_generator(conf, 48) expires_generator = create_token_expires_in_generator(config) return BearerToken( - access_token_generator, - refresh_token_generator, - expires_generator + access_token_generator, refresh_token_generator, expires_generator ) @@ -104,7 +117,7 @@ def create_token_expires_in_generator(config): data = {} data.update(BearerToken.GRANT_TYPES_EXPIRES_IN) - expires_in_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN') + expires_in_conf = config.get("OAUTH2_TOKEN_EXPIRES_IN") if expires_in_conf: data.update(expires_in_conf) @@ -123,8 +136,10 @@ def create_token_generator(token_generator_conf, length=42): return import_string(token_generator_conf) if token_generator_conf is True: + def token_generator(*args, **kwargs): # pylint: disable=W0613 return generate_token(length) + return token_generator return None diff --git a/authlib/integrations/fastapi_oauth2/resource_protector.py b/authlib/integrations/fastapi_oauth2/resource_protector.py index 5e89cea7..df1d8d97 100644 --- a/authlib/integrations/fastapi_oauth2/resource_protector.py +++ b/authlib/integrations/fastapi_oauth2/resource_protector.py @@ -2,62 +2,53 @@ import functools from contextlib import contextmanager + +from authlib.oauth2 import OAuth2Error +from authlib.oauth2 import ResourceProtector as _ResourceProtector +from authlib.oauth2.rfc6749 import HttpRequest, MissingAuthorizationError from fastapi import HTTPException -from authlib.oauth2 import ( - OAuth2Error, - ResourceProtector as _ResourceProtector -) -from authlib.oauth2.rfc6749 import ( - MissingAuthorizationError, - HttpRequest, -) class ResourceProtector(_ResourceProtector): """ResourceProtector class.""" - def acquire_token(self, request=None, scope=None, operator='AND'): + def acquire_token(self, request=None, scope=None): """A method to acquire current valid token with the given scope. :param request: request object :param scope: string or list of scope values - :param operator: value of "AND" or "OR" :return: token object """ - request = HttpRequest( - request.method, - request.url, - {}, - request.headers - ) - if not callable(operator): - operator = operator.upper() - token = self.validate_request(scope, request, operator) + http_request = HttpRequest(request.method, request.url, {}, request.headers) + token = self.validate_request(scope, http_request) + request.state.token = token return token @contextmanager - def acquire(self, request=None, scope=None, operator='AND'): + def acquire(self, request=None, scope=None): """The with statement of ``require_oauth``. Instead of using a decorator, you can use a with statement instead.""" try: - yield self.acquire_token(request, scope, operator) + yield self.acquire_token(request, scope) except OAuth2Error as error: raise_error_response(error) - def __call__(self, scope=None, operator='AND', optional=False): + def __call__(self, scope=None, optional=False): def wrapper(func): @functools.wraps(func) - def decorated(*args, **kwargs): + def decorated(request, *args, **kwargs): try: - self.acquire_token(scope, operator) + self.acquire_token(request, scope) except MissingAuthorizationError as error: if optional: - return func(*args, **kwargs) + return func(request, *args, **kwargs) raise_error_response(error) except OAuth2Error as error: raise_error_response(error) - return func(*args, **kwargs) + return func(request, *args, **kwargs) + return decorated + return wrapper @@ -66,8 +57,4 @@ def raise_error_response(error): status = error.status_code body = dict(error.get_body()) headers = error.get_headers() - raise HTTPException( - status_code=status, - detail=body, - headers=dict(headers) - ) + raise HTTPException(status_code=status, detail=body, headers=dict(headers)) diff --git a/tests/fastapi/test_oauth2/models.py b/tests/fastapi/test_oauth2/models.py index 12937077..13e2783d 100644 --- a/tests/fastapi/test_oauth2/models.py +++ b/tests/fastapi/test_oauth2/models.py @@ -1,16 +1,16 @@ import time -from sqlalchemy import Column, ForeignKey, Integer, String, create_engine + +from authlib.integrations.sqla_oauth2 import (OAuth2AuthorizationCodeMixin, + OAuth2ClientMixin, + OAuth2TokenMixin) +from authlib.oidc.core import UserInfo +from sqlalchemy import (Boolean, Column, ForeignKey, Integer, String, + create_engine) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker -from authlib.integrations.sqla_oauth2 import ( - OAuth2ClientMixin, - OAuth2TokenMixin, - OAuth2AuthorizationCodeMixin, -) -from authlib.oidc.core import UserInfo engine = create_engine( - 'sqlite:///fastapi_auth2_sql.db', connect_args={'check_same_thread': False} + "sqlite:///fastapi_auth2_sql.db", connect_args={"check_same_thread": False} ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -21,7 +21,7 @@ class User(Base): - __tablename__ = 'user' + __tablename__ = "user" id = Column(Integer, primary_key=True) username = Column(String(40), unique=True, nullable=False) @@ -30,43 +30,39 @@ def get_user_id(self): return self.id def check_password(self, password): - return password != 'wrong' + return password != "wrong" def generate_user_info(self, scopes): - profile = {'sub': str(self.id), 'name': self.username} + profile = {"sub": str(self.id), "name": self.username} return UserInfo(profile) class Client(Base, OAuth2ClientMixin): - __tablename__ = 'oauth2_client' + __tablename__ = "oauth2_client" id = Column(Integer, primary_key=True) - user_id = Column( - Integer, ForeignKey('user.id', ondelete='CASCADE') - ) - user = relationship('User') + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE")) + user = relationship("User") class AuthorizationCode(Base, OAuth2AuthorizationCodeMixin): - __tablename__ = 'oauth2_code' + __tablename__ = "oauth2_code" id = Column(Integer, primary_key=True) user_id = Column(Integer, nullable=False) @property def user(self): - return db.query(User).filter( - User.id == self.user_id).first() + return db.query(User).filter(User.id == self.user_id).first() class Token(Base, OAuth2TokenMixin): - __tablename__ = 'oauth2_token' + __tablename__ = "oauth2_token" id = Column(Integer, primary_key=True) - user_id = Column( - Integer, ForeignKey('user.id', ondelete='CASCADE') - ) - user = relationship('User') + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE")) + user = relationship("User") + revoked = Column(Boolean) def is_refresh_token_expired(self): expired_at = self.issued_at + self.expires_in * 2 @@ -75,9 +71,13 @@ def is_refresh_token_expired(self): class CodeGrantMixin(object): def query_authorization_code(self, code, client): - item = db.query(AuthorizationCode).filter( - AuthorizationCode.code == code, - Client.client_id == client.client_id).first() + item = ( + db.query(AuthorizationCode) + .filter( + AuthorizationCode.code == code, Client.client_id == client.client_id + ) + .first() + ) if item and not item.is_expired(): return item @@ -86,8 +86,7 @@ def delete_authorization_code(self, authorization_code): db.commit() def authenticate_user(self, authorization_code): - return db.query(User).filter( - User.id == authorization_code.user_id).first() + return db.query(User).filter(User.id == authorization_code.user_id).first() def save_authorization_code(code, request): @@ -97,10 +96,10 @@ def save_authorization_code(code, request): client_id=client.client_id, redirect_uri=request.redirect_uri, scope=request.scope, - nonce=request.data.get('nonce'), + nonce=request.data.get("nonce"), user_id=request.user.id, - code_challenge=request.data.get('code_challenge'), - code_challenge_method=request.data.get('code_challenge_method'), + code_challenge=request.data.get("code_challenge"), + code_challenge_method=request.data.get("code_challenge_method"), ) db.add(auth_code) db.commit() @@ -108,6 +107,9 @@ def save_authorization_code(code, request): def exists_nonce(nonce, req): - exists = db.query(AuthorizationCode).filter( - Client.client_id == req.client_id, AuthorizationCode.nonce == nonce).first() + exists = ( + db.query(AuthorizationCode) + .filter(Client.client_id == req.client_id, AuthorizationCode.nonce == nonce) + .first() + ) return bool(exists) diff --git a/tests/fastapi/test_oauth2/oauth2_server.py b/tests/fastapi/test_oauth2/oauth2_server.py index 3c03280e..d5a70760 100644 --- a/tests/fastapi/test_oauth2/oauth2_server.py +++ b/tests/fastapi/test_oauth2/oauth2_server.py @@ -1,27 +1,25 @@ -import os import base64 +import os import unittest -from fastapi import FastAPI, Request, Form -from fastapi.testclient import TestClient -from authlib.common.security import generate_token + from authlib.common.encoding import to_bytes, to_unicode +from authlib.common.security import generate_token from authlib.common.urls import url_encode -from authlib.integrations.sqla_oauth2 import ( - create_query_client_func, - create_save_token_func, -) from authlib.integrations.fastapi_oauth2 import AuthorizationServer +from authlib.integrations.sqla_oauth2 import (create_query_client_func, + create_save_token_func) from authlib.oauth2 import OAuth2Error -from .models import Base, engine, db, User, Client, Token +from fastapi import FastAPI, Form, Request +from fastapi.testclient import TestClient -os.environ['AUTHLIB_INSECURE_TRANSPORT'] = 'true' +from .models import Base, Client, Token, User, db, engine def token_generator(client, grant_type, user=None, scope=None): - token = '{}-{}'.format(client.client_id[0], grant_type) + token = "{}-{}".format(client.client_id[0], grant_type) if user: - token = '{}.{}'.format(token, user.get_user_id()) - return '{}.{}'.format(token, generate_token(32)) + token = "{}.{}".format(token, user.get_user_id()) + return "{}.{}".format(token, generate_token(32)) def create_authorization_server(app): @@ -31,9 +29,9 @@ def create_authorization_server(app): server = AuthorizationServer() server.init_app(app, query_client, save_token) - @app.get('/oauth/authorize') + @app.get("/oauth/authorize") def authorize_get(request: Request): - user_id = request.query_params.get('user_id') + user_id = request.query_params.get("user_id") request.body = {} if user_id: end_user = db.query(User).filter(User.id == int(user_id)).first() @@ -41,115 +39,118 @@ def authorize_get(request: Request): end_user = None try: grant = server.validate_consent_request(request=request, end_user=end_user) - return grant.prompt or 'ok' + return grant.prompt or "ok" except OAuth2Error as error: return url_encode(error.get_body()) - @app.post('/oauth/authorize') - def authorize_post(request: Request, - response_type: str = Form(None), - client_id: str = Form(None), - state: str = Form(None), - scope: str = Form(None), - nonce: str = Form(None), - redirect_uri: str = Form(None), - response_mode: str = Form(None), - user_id: str = Form(None)): + @app.post("/oauth/authorize") + def authorize_post( + request: Request, + response_type: str = Form(None), + client_id: str = Form(None), + state: str = Form(None), + scope: str = Form(None), + nonce: str = Form(None), + redirect_uri: str = Form(None), + response_mode: str = Form(None), + user_id: str = Form(None), + ): if not user_id: - user_id = request.query_params.get('user_id') + user_id = request.query_params.get("user_id") - request.body = { - 'user_id': user_id - } + request.body = {"user_id": user_id} if response_type: - request.body.update({'response_type': response_type}) + request.body.update({"response_type": response_type}) if client_id: - request.body.update({'client_id': client_id}) + request.body.update({"client_id": client_id}) if state: - request.body.update({'state': state}) + request.body.update({"state": state}) if nonce: - request.body.update({'nonce': nonce}) + request.body.update({"nonce": nonce}) if scope: - request.body.update({'scope': scope}) + request.body.update({"scope": scope}) if redirect_uri: - request.body.update({'redirect_uri': redirect_uri}) + request.body.update({"redirect_uri": redirect_uri}) if response_mode: - request.body.update({'response_mode': response_mode}) + request.body.update({"response_mode": response_mode}) if user_id: grant_user = db.query(User).filter(User.id == int(user_id)).first() else: grant_user = None - return server.create_authorization_response(request=request, grant_user=grant_user) + return server.create_authorization_response( + request=request, grant_user=grant_user + ) - @app.api_route('/oauth/token', methods=["GET", "POST"]) + @app.api_route("/oauth/token", methods=["GET", "POST"]) def issue_token( - request: Request, - grant_type: str = Form(None), - scope: str = Form(None), - code: str = Form(None), - refresh_token: str = Form(None), - code_verifier: str = Form(None), - client_id: str = Form(None), - client_secret: str = Form(None), - device_code: str = Form(None), - client_assertion_type: str = Form(None), - client_assertion: str = Form(None), - assertion: str = Form(None), - username: str = Form(None), - password: str = Form(None), - redirect_uri: str = Form(None)): + request: Request, + grant_type: str = Form(None), + scope: str = Form(None), + code: str = Form(None), + refresh_token: str = Form(None), + code_verifier: str = Form(None), + client_id: str = Form(None), + client_secret: str = Form(None), + device_code: str = Form(None), + client_assertion_type: str = Form(None), + client_assertion: str = Form(None), + assertion: str = Form(None), + username: str = Form(None), + password: str = Form(None), + redirect_uri: str = Form(None), + ): request.body = { - 'grant_type': grant_type, - 'scope': scope, + "grant_type": grant_type, + "scope": scope, } if not grant_type: - grant_type = request.query_params.get('grant_type') - request.body.update({'grant_type': grant_type}) + grant_type = request.query_params.get("grant_type") + request.body.update({"grant_type": grant_type}) - if grant_type == 'authorization_code': - request.body.update({'code': code}) - elif grant_type == 'refresh_token': - request.body.update({'refresh_token': refresh_token}) + if grant_type == "authorization_code": + request.body.update({"code": code}) + elif grant_type == "refresh_token": + request.body.update({"refresh_token": refresh_token}) if code_verifier: - request.body.update({'code_verifier': code_verifier}) + request.body.update({"code_verifier": code_verifier}) if client_id: - request.body.update({'client_id': client_id}) + request.body.update({"client_id": client_id}) if client_secret: - request.body.update({'client_secret': client_secret}) + request.body.update({"client_secret": client_secret}) if device_code: - request.body.update({'device_code': device_code}) + request.body.update({"device_code": device_code}) if client_assertion_type: - request.body.update({'client_assertion_type': client_assertion_type}) + request.body.update({"client_assertion_type": client_assertion_type}) if client_assertion: - request.body.update({'client_assertion': client_assertion}) + request.body.update({"client_assertion": client_assertion}) if assertion: - request.body.update({'assertion': assertion}) + request.body.update({"assertion": assertion}) if redirect_uri: - request.body.update({'redirect_uri': redirect_uri}) + request.body.update({"redirect_uri": redirect_uri}) if username: - request.body.update({'username': username}) + request.body.update({"username": username}) if password: - request.body.update({'password': password}) + request.body.update({"password": password}) return server.create_token_response(request=request) @@ -160,18 +161,17 @@ def create_fastapi_app(): app = FastAPI() app.debug = True app.testing = True - app.secret_key = 'testing' + app.secret_key = "testing" app.test_client = TestClient(app) app.config = { - 'OAUTH2_ERROR_URIS': [ - ('invalid_client', 'https://a.b/e#invalid_client') - ] + "OAUTH2_ERROR_URIS": [("invalid_client", "https://a.b/e#invalid_client")] } return app class TestCase(unittest.TestCase): def setUp(self): + os.environ["AUTHLIB_INSECURE_TRANSPORT"] = "true" app = create_fastapi_app() Base.metadata.create_all(bind=engine) @@ -181,8 +181,9 @@ def setUp(self): def tearDown(self): Base.metadata.drop_all(bind=engine) + os.environ.pop("AUTHLIB_INSECURE_TRANSPORT") def create_basic_header(self, username, password): - text = '{}:{}'.format(username, password) + text = "{}:{}".format(username, password) auth = to_unicode(base64.b64encode(to_bytes(text))) - return {'Authorization': 'Basic ' + auth} + return {"Authorization": "Basic " + auth} diff --git a/tests/fastapi/test_oauth2/test_authorization_code_grant.py b/tests/fastapi/test_oauth2/test_authorization_code_grant.py index 907b0da1..d244e4fd 100644 --- a/tests/fastapi/test_oauth2/test_authorization_code_grant.py +++ b/tests/fastapi/test_oauth2/test_authorization_code_grant.py @@ -1,56 +1,58 @@ -from authlib.common.urls import urlparse, url_decode -from authlib.oauth2.rfc6749.grants import ( - AuthorizationCodeGrant as _AuthorizationCodeGrant, -) -from .models import db, User, Client, AuthorizationCode -from .models import CodeGrantMixin, save_authorization_code -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.common.urls import url_decode, urlparse +from authlib.oauth2.rfc6749.grants import \ + AuthorizationCodeGrant as _AuthorizationCodeGrant + +from .models import (AuthorizationCode, Client, CodeGrantMixin, User, db, + save_authorization_code) +from .oauth2_server import TestCase, create_authorization_server class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): - TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] + TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post", "none"] def save_authorization_code(self, code, request): return save_authorization_code(code, request) class AuthorizationCodeTest(TestCase): - def register_grant(self, server): server.register_grant(AuthorizationCodeGrant) def prepare_data( - self, is_confidential=True, - response_type='code', grant_type='authorization_code', - token_endpoint_auth_method='client_secret_basic'): + self, + is_confidential=True, + response_type="code", + grant_type="authorization_code", + token_endpoint_auth_method="client_secret_basic", + ): server = create_authorization_server(self.app) self.register_grant(server) self.server = server - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() if is_confidential: - client_secret = 'code-secret' + client_secret = "code-secret" else: - client_secret = '' + client_secret = "" client = Client( user_id=user.id, - client_id='code-client', + client_id="code-client", client_secret=client_secret, ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b'], - 'scope': 'profile address', - 'token_endpoint_auth_method': token_endpoint_auth_method, - 'response_types': [response_type], - 'grant_types': grant_type.splitlines(), - }) + client.set_client_metadata( + { + "redirect_uris": ["https://a.b"], + "scope": "profile address", + "token_endpoint_auth_method": token_endpoint_auth_method, + "response_types": [response_type], + "grant_types": grant_type.splitlines(), + } + ) self.authorize_url = ( - '/oauth/authorize?response_type=code' - '&client_id=code-client' + "/oauth/authorize?response_type=code" "&client_id=code-client" ) db.add(client) db.commit() @@ -58,195 +60,229 @@ def prepare_data( def test_get_authorize(self): self.prepare_data() rv = self.client.get(self.authorize_url) - self.assertEqual(rv.json(), 'ok') + self.assertEqual(rv.json(), "ok") def test_invalid_client_id(self): self.prepare_data() - url = '/oauth/authorize?response_type=code' + url = "/oauth/authorize?response_type=code" rv = self.client.get(url) - self.assertIn('invalid_client', rv.json()) + self.assertIn("invalid_client", rv.json()) - url = '/oauth/authorize?response_type=code&client_id=invalid' + url = "/oauth/authorize?response_type=code&client_id=invalid" rv = self.client.get(url) - self.assertIn('invalid_client', rv.json()) + self.assertIn("invalid_client", rv.json()) def test_invalid_authorize(self): self.prepare_data() rv = self.client.post(self.authorize_url) - self.assertIn('error=access_denied', rv.headers['location']) + self.assertIn("error=access_denied", rv.headers["location"]) - self.server.metadata = {'scopes_supported': ['profile']} - rv = self.client.post(self.authorize_url + '&scope=invalid&state=foo') - self.assertIn('error=invalid_scope', rv.headers['location']) - self.assertIn('state=foo', rv.headers['location']) + self.server.scopes_supported = ["profile"] + rv = self.client.post(self.authorize_url + "&scope=invalid&state=foo") + self.assertIn("error=invalid_scope", rv.headers["location"]) + self.assertIn("state=foo", rv.headers["location"]) def test_unauthorized_client(self): - self.prepare_data(True, 'token') + self.prepare_data(True, "token") rv = self.client.get(self.authorize_url) - self.assertIn('unauthorized_client', rv.json()) + self.assertIn("unauthorized_client", rv.json()) def test_invalid_client(self): self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'invalid', - 'client_id': 'invalid-id', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + "client_id": "invalid-id", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header('code-client', 'invalid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'invalid', - }, headers=headers) + self.assertEqual(resp["error"], "invalid_client") + + headers = self.create_basic_header("code-client", "invalid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') - self.assertEqual(resp['error_uri'], 'https://a.b/e#invalid_client') + self.assertEqual(resp["error"], "invalid_client") + self.assertEqual(resp["error_uri"], "https://a.b/e#invalid_client") def test_invalid_code(self): self.prepare_data() - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - }, headers=headers) + headers = self.create_basic_header("code-client", "code-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'invalid', - }, headers=headers) + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "invalid", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_request") - code = AuthorizationCode( - code='no-user', - client_id='code-client', - user_id=0 - ) + code = AuthorizationCode(code="no-user", client_id="code-client", user_id=0) db.add(code) db.commit() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'no-user', - }, headers=headers) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "no-user", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_request") def test_invalid_redirect_uri(self): self.prepare_data() - uri = self.authorize_url + '&redirect_uri=https%3A%2F%2Fa.c' - rv = self.client.post(uri, data={'user_id': '1'}) + uri = self.authorize_url + "&redirect_uri=https%3A%2F%2Fa.c" + rv = self.client.post(uri, data={"user_id": "1"}) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') - - uri = self.authorize_url + '&redirect_uri=https%3A%2F%2Fa.b' - rv = self.client.post(uri, data={'user_id': '1'}) - self.assertIn('code=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - }, headers=headers) + self.assertEqual(resp["error"], "invalid_request") + + uri = self.authorize_url + "&redirect_uri=https%3A%2F%2Fa.b" + rv = self.client.post(uri, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + code = params["code"] + headers = self.create_basic_header("code-client", "code-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_request") def test_invalid_grant_type(self): self.prepare_data( - False, token_endpoint_auth_method='none', - grant_type='invalid' - ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'client_id': 'code-client', - 'code': 'a', - }) + False, token_endpoint_auth_method="none", grant_type="invalid" + ) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "client_id": "code-client", + "code": "a", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'unauthorized_client') + self.assertEqual(resp["error"], "unauthorized_client") def test_authorize_token_no_refresh_token(self): - self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) - self.prepare_data(False, token_endpoint_auth_method='none') - - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('code=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'client_id': 'code-client', - }) + self.app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + self.prepare_data(False, token_endpoint_auth_method="none") + + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + code = params["code"] + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": "code-client", + }, + ) resp = rv.json() - self.assertIn('access_token', resp) - self.assertNotIn('refresh_token', resp) + self.assertIn("access_token", resp) + self.assertNotIn("refresh_token", resp) def test_authorize_token_has_refresh_token(self): # generate refresh token - self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) - self.prepare_data(grant_type='authorization_code\nrefresh_token') - url = self.authorize_url + '&state=bar' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - headers = self.create_basic_header('code-client', 'code-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - }, headers=headers) + self.app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) + self.prepare_data(grant_type="authorization_code\nrefresh_token") + url = self.authorize_url + "&state=bar" + rv = self.client.post(url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.assertEqual(params["state"], "bar") + + code = params["code"] + headers = self.create_basic_header("code-client", "code-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + }, + headers=headers, + ) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('refresh_token', resp) + self.assertIn("access_token", resp) + self.assertIn("refresh_token", resp) def test_client_secret_post(self): - self.app.config.update({'OAUTH2_REFRESH_TOKEN_GENERATOR': True}) + self.app.config.update({"OAUTH2_REFRESH_TOKEN_GENERATOR": True}) self.prepare_data( - grant_type='authorization_code\nrefresh_token', - token_endpoint_auth_method='client_secret_post', - ) - url = self.authorize_url + '&state=bar' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('code=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'client_id': 'code-client', - 'client_secret': 'code-secret', - 'code': code, - }) + grant_type="authorization_code\nrefresh_token", + token_endpoint_auth_method="client_secret_post", + ) + url = self.authorize_url + "&state=bar" + rv = self.client.post(url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.assertEqual(params["state"], "bar") + + code = params["code"] + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "client_id": "code-client", + "client_secret": "code-secret", + "code": code, + }, + ) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('refresh_token', resp) + self.assertIn("access_token", resp) + self.assertIn("refresh_token", resp) def test_token_generator(self): - m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) - self.prepare_data(False, token_endpoint_auth_method='none') - - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('code=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) - code = params['code'] - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': code, - 'client_id': 'code-client', - }) + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) + self.prepare_data(False, token_endpoint_auth_method="none") + + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("code=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + code = params["code"] + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": code, + "client_id": "code-client", + }, + ) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('c-authorization_code.1.', resp['access_token']) + self.assertIn("access_token", resp) + self.assertIn("c-authorization_code.1.", resp["access_token"]) diff --git a/tests/fastapi/test_oauth2/test_client_credentials_grant.py b/tests/fastapi/test_oauth2/test_client_credentials_grant.py index 19f48fb1..29470d78 100644 --- a/tests/fastapi/test_oauth2/test_client_credentials_grant.py +++ b/tests/fastapi/test_oauth2/test_client_credentials_grant.py @@ -1,94 +1,109 @@ from authlib.oauth2.rfc6749.grants import ClientCredentialsGrant -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server class ClientCredentialsTest(TestCase): - def prepare_data(self, grant_type='client_credentials'): + def prepare_data(self, grant_type="client_credentials"): server = create_authorization_server(self.app) server.register_grant(ClientCredentialsGrant) self.server = server - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() client = Client( user_id=user.id, - client_id='credential-client', - client_secret='credential-secret', + client_id="credential-client", + client_secret="credential-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["http://localhost/authorized"], + "grant_types": [grant_type], + } ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://localhost/authorized'], - 'grant_types': [grant_type] - }) db.add(client) db.commit() def test_invalid_client(self): self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") - headers = self.create_basic_header( - 'credential-client', 'invalid-secret' + headers = self.create_basic_header("credential-client", "invalid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") def test_invalid_grant_type(self): - self.prepare_data(grant_type='invalid') - headers = self.create_basic_header( - 'credential-client', 'credential-secret' + self.prepare_data(grant_type="invalid") + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'unauthorized_client') + self.assertEqual(resp["error"], "unauthorized_client") def test_invalid_scope(self): self.prepare_data() - self.server.metadata = {'scopes_supported': ['profile']} - headers = self.create_basic_header( - 'credential-client', 'credential-secret' + self.server.scopes_supported = ["profile"] + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + "scope": "invalid", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - 'scope': 'invalid', - }, headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_scope') + self.assertEqual(resp["error"], "invalid_scope") def test_authorize_token(self): self.prepare_data() - headers = self.create_basic_header( - 'credential-client', 'credential-secret' + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) resp = rv.json() - self.assertIn('access_token', resp) + self.assertIn("access_token", resp) def test_token_generator(self): - m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) self.prepare_data() - headers = self.create_basic_header( - 'credential-client', 'credential-secret' + headers = self.create_basic_header("credential-client", "credential-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "client_credentials", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'client_credentials', - }, headers=headers) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('c-client_credentials.', resp['access_token']) + self.assertIn("access_token", resp) + self.assertIn("c-client_credentials.", resp["access_token"]) diff --git a/tests/fastapi/test_oauth2/test_client_registration_endpoint.py b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py index d8bc7ddb..a0a4c936 100644 --- a/tests/fastapi/test_oauth2/test_client_registration_endpoint.py +++ b/tests/fastapi/test_oauth2/test_client_registration_endpoint.py @@ -1,30 +1,28 @@ -from pydantic import BaseModel -from fastapi import Request from authlib.jose import jwt -from authlib.oauth2.rfc7591 import ClientRegistrationEndpoint as _ClientRegistrationEndpoint +from authlib.oauth2.rfc7591 import \ + ClientRegistrationEndpoint as _ClientRegistrationEndpoint +from fastapi import Request +from pydantic import BaseModel from tests.util import read_file_path -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server class ClientRegistrationEndpoint(_ClientRegistrationEndpoint): - software_statement_alg_values_supported = ['RS256'] + software_statement_alg_values_supported = ["RS256"] def authenticate_token(self, request): - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header: request.user_id = 1 return auth_header def resolve_public_key(self, request): - return read_file_path('rsa_public.pem') + return read_file_path("rsa_public.pem") def save_client(self, client_info, client_metadata, request): - client = Client( - user_id=request.user_id, - **client_info - ) + client = Client(user_id=request.user_id, **client_info) client.set_client_metadata(client_metadata) db.add(client) db.commit() @@ -38,9 +36,15 @@ def prepare_data(self, endpoint_cls=None, metadata=None): if metadata: server.metadata = metadata - if endpoint_cls is None: - endpoint_cls = ClientRegistrationEndpoint - server.register_endpoint(endpoint_cls) + if endpoint_cls: + server.register_endpoint(endpoint_cls) + else: + + class MyClientRegistration(ClientRegistrationEndpoint): + def get_server_metadata(self): + return metadata + + server.register_endpoint(MyClientRegistration) class Item(BaseModel): client_name: str = None @@ -52,142 +56,144 @@ class Item(BaseModel): grant_types: list = None response_types: list = None - @app.post('/create_client') + @app.post("/create_client") def create_client(request: Request, item: Item = None): request.body = {} if item: request.body = { - 'client_name': item.client_name, - 'client_uri': item.client_uri, - 'redirect_uri': item.redirect_uri, - 'scope': item.scope, - 'software_statement': item.software_statement, - 'token_endpoint_auth_method': item.token_endpoint_auth_method, - 'grant_types': item.grant_types, - 'response_types': item.response_types, + "client_name": item.client_name, + "client_uri": item.client_uri, + "redirect_uri": item.redirect_uri, + "scope": item.scope, + "software_statement": item.software_statement, + "token_endpoint_auth_method": item.token_endpoint_auth_method, + "grant_types": item.grant_types, + "response_types": item.response_types, } - return server.create_endpoint_response('client_registration', request=request) + return server.create_endpoint_response( + "client_registration", request=request + ) - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() def test_access_denied(self): self.prepare_data() - rv = self.client.post('/create_client') + rv = self.client.post("/create_client") resp = rv.json() - self.assertEqual(resp['error'], 'access_denied') + self.assertEqual(resp["error"], "access_denied") def test_invalid_request(self): self.prepare_data() - headers = {'Authorization': 'bearer abc'} - rv = self.client.post('/create_client', headers=headers) + headers = {"Authorization": "bearer abc"} + rv = self.client.post("/create_client", headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_request") def test_create_client(self): self.prepare_data() - headers = {'Authorization': 'bearer abc'} - body = { - 'client_name': 'Authlib' - } - rv = self.client.post('/create_client', json=body, headers=headers) + headers = {"Authorization": "bearer abc"} + body = {"client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") def test_software_statement(self): - payload = {'software_id': 'uuid-123', 'client_name': 'Authlib'} - s = jwt.encode({'alg': 'RS256'}, payload, read_file_path('rsa_private.pem')) + payload = {"software_id": "uuid-123", "client_name": "Authlib"} + s = jwt.encode({"alg": "RS256"}, payload, read_file_path("rsa_private.pem")) body = { - 'software_statement': s.decode('utf-8'), + "software_statement": s.decode("utf-8"), } self.prepare_data() - headers = {'Authorization': 'bearer abc'} - rv = self.client.post('/create_client', json=body, headers=headers) + headers = {"Authorization": "bearer abc"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") def test_no_public_key(self): - class ClientRegistrationEndpoint2(ClientRegistrationEndpoint): def resolve_public_key(self, request): return None - payload = {'software_id': 'uuid-123', 'client_name': 'Authlib'} - s = jwt.encode({'alg': 'RS256'}, payload, read_file_path('rsa_private.pem')) + payload = {"software_id": "uuid-123", "client_name": "Authlib"} + s = jwt.encode({"alg": "RS256"}, payload, read_file_path("rsa_private.pem")) body = { - 'software_statement': s.decode('utf-8'), + "software_statement": s.decode("utf-8"), } self.prepare_data(ClientRegistrationEndpoint2) - headers = {'Authorization': 'bearer abc'} - rv = self.client.post('/create_client', json=body, headers=headers) + headers = {"Authorization": "bearer abc"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn(resp['error'], 'unapproved_software_statement') + self.assertIn(resp["error"], "unapproved_software_statement") def test_scopes_supported(self): - metadata = {'scopes_supported': ['profile', 'email']} + metadata = {"scopes_supported": ["profile", "email"]} self.prepare_data(metadata=metadata) - headers = {'Authorization': 'bearer abc'} - body = {'scope': 'profile email', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + headers = {"Authorization": "bearer abc"} + body = {"scope": "profile email", "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") - body = {'scope': 'profile email address', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + body = {"scope": "profile email address", "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn(resp['error'], 'invalid_client_metadata') + self.assertIn(resp["error"], "invalid_client_metadata") def test_response_types_supported(self): - metadata = {'response_types_supported': ['code']} + metadata = {"response_types_supported": ["code"]} self.prepare_data(metadata=metadata) - headers = {'Authorization': 'bearer abc'} - body = {'response_types': ['code'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + headers = {"Authorization": "bearer abc"} + body = {"response_types": ["code"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") - body = {'response_types': ['code', 'token'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + body = {"response_types": ["code", "token"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn(resp['error'], 'invalid_client_metadata') + self.assertIn(resp["error"], "invalid_client_metadata") def test_grant_types_supported(self): - metadata = {'grant_types_supported': ['authorization_code', 'password']} + metadata = {"grant_types_supported": ["authorization_code", "password"]} self.prepare_data(metadata=metadata) - headers = {'Authorization': 'bearer abc'} - body = {'grant_types': ['password'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + headers = {"Authorization": "bearer abc"} + body = {"grant_types": ["password"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") - body = {'grant_types': ['client_credentials'], 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + body = {"grant_types": ["client_credentials"], "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn(resp['error'], 'invalid_client_metadata') + self.assertIn(resp["error"], "invalid_client_metadata") def test_token_endpoint_auth_methods_supported(self): - metadata = {'token_endpoint_auth_methods_supported': ['client_secret_basic']} + metadata = {"token_endpoint_auth_methods_supported": ["client_secret_basic"]} self.prepare_data(metadata=metadata) - headers = {'Authorization': 'bearer abc'} - body = {'token_endpoint_auth_method': 'client_secret_basic', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + headers = {"Authorization": "bearer abc"} + body = { + "token_endpoint_auth_method": "client_secret_basic", + "client_name": "Authlib", + } + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn('client_id', resp) - self.assertEqual(resp['client_name'], 'Authlib') + self.assertIn("client_id", resp) + self.assertEqual(resp["client_name"], "Authlib") - body = {'token_endpoint_auth_method': 'none', 'client_name': 'Authlib'} - rv = self.client.post('/create_client', json=body, headers=headers) + body = {"token_endpoint_auth_method": "none", "client_name": "Authlib"} + rv = self.client.post("/create_client", json=body, headers=headers) resp = rv.json() - self.assertIn(resp['error'], 'invalid_client_metadata') + self.assertIn(resp["error"], "invalid_client_metadata") diff --git a/tests/fastapi/test_oauth2/test_device_code_grant.py b/tests/fastapi/test_oauth2/test_device_code_grant.py index bd5518b1..45729625 100644 --- a/tests/fastapi/test_oauth2/test_device_code_grant.py +++ b/tests/fastapi/test_oauth2/test_device_code_grant.py @@ -1,48 +1,48 @@ import time -from fastapi import Request, Form -from authlib.oauth2.rfc8628 import ( - DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint, - DeviceCodeGrant as _DeviceCodeGrant, - DeviceCredentialDict, -) -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.oauth2.rfc8628 import \ + DeviceAuthorizationEndpoint as _DeviceAuthorizationEndpoint +from authlib.oauth2.rfc8628 import DeviceCodeGrant as _DeviceCodeGrant +from authlib.oauth2.rfc8628 import DeviceCredentialDict +from fastapi import Form, Request + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server device_credentials = { - 'valid-device': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'code', + "valid-device": { + "client_id": "client", + "expires_in": 1800, + "user_code": "code", + }, + "expired-token": { + "client_id": "client", + "expires_in": -100, + "user_code": "none", }, - 'expired-token': { - 'client_id': 'client', - 'expires_in': -100, - 'user_code': 'none', + "invalid-client": { + "client_id": "invalid", + "expires_in": 1800, + "user_code": "none", }, - 'invalid-client': { - 'client_id': 'invalid', - 'expires_in': 1800, - 'user_code': 'none', + "denied-code": { + "client_id": "client", + "expires_in": 1800, + "user_code": "denied", }, - 'denied-code': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'denied', + "grant-code": { + "client_id": "client", + "expires_in": 1800, + "user_code": "code", }, - 'grant-code': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'code', + "pending-code": { + "client_id": "client", + "expires_in": 1800, + "user_code": "none", }, - 'pending-code': { - 'client_id': 'client', - 'expires_in': 1800, - 'user_code': 'none', - } } + class DeviceCodeGrant(_DeviceCodeGrant): def query_device_credential(self, device_code): data = device_credentials.get(device_code) @@ -50,21 +50,21 @@ def query_device_credential(self, device_code): return None now = int(time.time()) - data['expires_at'] = now + data['expires_in'] - data['device_code'] = device_code - data['scope'] = 'profile' - data['interval'] = 5 - data['verification_uri'] = 'https://example.com/activate' + data["expires_at"] = now + data["expires_in"] + data["device_code"] = device_code + data["scope"] = "profile" + data["interval"] = 5 + data["verification_uri"] = "https://example.com/activate" return DeviceCredentialDict(data) def query_user_grant(self, user_code): - if user_code == 'code': + if user_code == "code": return db.query(User).filter(User.id == 1).first(), True - if user_code == 'denied': + if user_code == "denied": return db.query(User).filter(User.id == 1).first(), False return None - def should_slow_down(self, credential, now): + def should_slow_down(self, credential): return False @@ -76,124 +76,148 @@ def create_server(self): return server def prepare_data(self, grant_type=DeviceCodeGrant.GRANT_TYPE): - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() client = Client( user_id=user.id, - client_id='client', - client_secret='secret', + client_id="client", + client_secret="secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["http://localhost/authorized"], + "scope": "profile", + "grant_types": [grant_type], + "token_endpoint_auth_method": "none", + } ) - client.set_client_metadata({ - 'redirect_uris': ['http://localhost/authorized'], - 'scope': 'profile', - 'grant_types': [grant_type], - }) db.add(client) db.commit() def test_invalid_request(self): self.create_server() self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - }) - resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'valid-device', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "client_id": "client", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_request") - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'missing', - 'client_id': 'client', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "missing", + "client_id": "client", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_request") def test_unauthorized_client(self): self.create_server() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'valid-device', - 'client_id': 'invalid', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "valid-device", + "client_id": "invalid", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'unauthorized_client') - - self.prepare_data(grant_type='password') - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'valid-device', - 'client_id': 'client', - }) + self.assertEqual(resp["error"], "invalid_client") + + self.prepare_data(grant_type="password") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "valid-device", + "client_id": "client", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'unauthorized_client') + self.assertEqual(resp["error"], "unauthorized_client") def test_invalid_client(self): self.create_server() self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'invalid-client', - 'client_id': 'invalid', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "invalid-client", + "client_id": "invalid", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") def test_expired_token(self): self.create_server() self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'expired-token', - 'client_id': 'client', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "expired-token", + "client_id": "client", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'expired_token') + self.assertEqual(resp["error"], "expired_token") def test_denied_by_user(self): self.create_server() self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'denied-code', - 'client_id': 'client', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "denied-code", + "client_id": "client", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'access_denied') + self.assertEqual(resp["error"], "access_denied") def test_authorization_pending(self): self.create_server() self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'pending-code', - 'client_id': 'client', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "pending-code", + "client_id": "client", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'authorization_pending') + self.assertEqual(resp["error"], "authorization_pending") def test_get_access_token(self): self.create_server() self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': DeviceCodeGrant.GRANT_TYPE, - 'device_code': 'grant-code', - 'client_id': 'client', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": DeviceCodeGrant.GRANT_TYPE, + "device_code": "grant-code", + "client_id": "client", + }, + ) resp = rv.json() - self.assertIn('access_token', resp) + self.assertIn("access_token", resp) class DeviceAuthorizationEndpoint(_DeviceAuthorizationEndpoint): def get_verification_uri(self): - return 'https://example.com/activate' + return "https://example.com/activate" def save_device_credential(self, client_id, scope, data): pass @@ -205,13 +229,13 @@ def create_server(self): server.register_endpoint(DeviceAuthorizationEndpoint) self.server = server - @self.app.post('/device_authorize') - def device_authorize(request: Request, - scope: str = Form(None), - client_id: str = Form(None)): + @self.app.post("/device_authorize") + def device_authorize( + request: Request, scope: str = Form(None), client_id: str = Form(None) + ): request.body = { - 'scope': scope, - 'client_id': client_id, + "scope": scope, + "client_id": client_id, } name = DeviceAuthorizationEndpoint.ENDPOINT_NAME return server.create_endpoint_response(name, request=request) @@ -220,24 +244,32 @@ def device_authorize(request: Request, def test_missing_client_id(self): self.create_server() - rv = self.client.post('/device_authorize', data={ - 'scope': 'profile' - }) - self.assertEqual(rv.status_code, 400) + rv = self.client.post("/device_authorize", data={"scope": "profile"}) + self.assertEqual(rv.status_code, 401) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_client") def test_create_authorization_response(self): self.create_server() - rv = self.client.post('/device_authorize', data={ - 'client_id': 'client', - }) + client = Client( + user_id=1, + client_id="client", + client_secret="secret", + ) + db.add(client) + db.commit() + rv = self.client.post( + "/device_authorize", + data={ + "client_id": "client", + }, + ) self.assertEqual(rv.status_code, 200) resp = rv.json() - self.assertIn('device_code', resp) - self.assertIn('user_code', resp) - self.assertEqual(resp['verification_uri'], 'https://example.com/activate') + self.assertIn("device_code", resp) + self.assertIn("user_code", resp) + self.assertEqual(resp["verification_uri"], "https://example.com/activate") self.assertEqual( - resp['verification_uri_complete'], - 'https://example.com/activate?user_code=' + resp['user_code'] + resp["verification_uri_complete"], + "https://example.com/activate?user_code=" + resp["user_code"], ) diff --git a/tests/fastapi/test_oauth2/test_implicit_grant.py b/tests/fastapi/test_oauth2/test_implicit_grant.py index 4653edb1..b875cb6d 100644 --- a/tests/fastapi/test_oauth2/test_implicit_grant.py +++ b/tests/fastapi/test_oauth2/test_implicit_grant.py @@ -1,40 +1,41 @@ from authlib.oauth2.rfc6749.grants import ImplicitGrant -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server class ImplicitTest(TestCase): - def prepare_data(self, is_confidential=False, response_type='token'): + def prepare_data(self, is_confidential=False, response_type="token"): server = create_authorization_server(self.app) server.register_grant(ImplicitGrant) self.server = server - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() if is_confidential: - client_secret = 'implicit-secret' - token_endpoint_auth_method = 'client_secret_basic' + client_secret = "implicit-secret" + token_endpoint_auth_method = "client_secret_basic" else: - client_secret = '' - token_endpoint_auth_method = 'none' + client_secret = "" + token_endpoint_auth_method = "none" client = Client( user_id=user.id, - client_id='implicit-client', + client_id="implicit-client", client_secret=client_secret, ) - client.set_client_metadata({ - 'redirect_uris': ['http://localhost/authorized'], - 'scope': 'profile', - 'response_types': [response_type], - 'grant_types': ['implicit'], - 'token_endpoint_auth_method': token_endpoint_auth_method, - }) + client.set_client_metadata( + { + "redirect_uris": ["http://localhost/authorized"], + "scope": "profile", + "response_types": [response_type], + "grant_types": ["implicit"], + "token_endpoint_auth_method": token_endpoint_auth_method, + } + ) self.authorize_url = ( - '/oauth/authorize?response_type=token' - '&client_id=implicit-client' + "/oauth/authorize?response_type=token" "&client_id=implicit-client" ) db.add(client) db.commit() @@ -42,41 +43,41 @@ def prepare_data(self, is_confidential=False, response_type='token'): def test_get_authorize(self): self.prepare_data() rv = self.client.get(self.authorize_url) - self.assertEqual(rv.json(), 'ok') + self.assertEqual(rv.json(), "ok") def test_confidential_client(self): self.prepare_data(True) rv = self.client.get(self.authorize_url) - self.assertIn('invalid_client', rv.json()) + self.assertIn("invalid_client", rv.json()) def test_unsupported_client(self): - self.prepare_data(response_type='code') + self.prepare_data(response_type="code") rv = self.client.get(self.authorize_url) - self.assertIn('unauthorized_client', rv.json()) + self.assertIn("unauthorized_client", rv.json()) def test_invalid_authorize(self): self.prepare_data() rv = self.client.post(self.authorize_url) - self.assertIn('#error=access_denied', rv.headers['location']) + self.assertIn("#error=access_denied", rv.headers["location"]) - self.server.metadata = {'scopes_supported': ['profile']} - rv = self.client.post(self.authorize_url + '&scope=invalid') - self.assertIn('#error=invalid_scope', rv.headers['location']) + self.server.scopes_supported = ["profile"] + rv = self.client.post(self.authorize_url + "&scope=invalid") + self.assertIn("#error=invalid_scope", rv.headers["location"]) def test_authorize_token(self): self.prepare_data() - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('access_token=', rv.headers['location']) + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("access_token=", rv.headers["location"]) - url = self.authorize_url + '&state=bar&scope=profile' - rv = self.client.post(url, data={'user_id': '1'}) - self.assertIn('access_token=', rv.headers['location']) - self.assertIn('state=bar', rv.headers['location']) - self.assertIn('scope=profile', rv.headers['location']) + url = self.authorize_url + "&state=bar&scope=profile" + rv = self.client.post(url, data={"user_id": "1"}) + self.assertIn("access_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + self.assertIn("scope=profile", rv.headers["location"]) def test_token_generator(self): - m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) self.prepare_data() - rv = self.client.post(self.authorize_url, data={'user_id': '1'}) - self.assertIn('access_token=i-implicit.1.', rv.headers['location']) + rv = self.client.post(self.authorize_url, data={"user_id": "1"}) + self.assertIn("access_token=i-implicit.1.", rv.headers["location"]) diff --git a/tests/fastapi/test_oauth2/test_introspection_endpoint.py b/tests/fastapi/test_oauth2/test_introspection_endpoint.py index 0e264fe2..2a9c3cca 100644 --- a/tests/fastapi/test_oauth2/test_introspection_endpoint.py +++ b/tests/fastapi/test_oauth2/test_introspection_endpoint.py @@ -1,17 +1,19 @@ -from fastapi import Request, Form from authlib.integrations.sqla_oauth2 import create_query_token_func from authlib.oauth2.rfc7662 import IntrospectionEndpoint -from .models import db, User, Client, Token -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from fastapi import Form, Request +from .models import Client, Token, User, db +from .oauth2_server import TestCase, create_authorization_server query_token = create_query_token_func(db, Token) class MyIntrospectionEndpoint(IntrospectionEndpoint): - def query_token(self, token, token_type_hint, client): - return query_token(token, token_type_hint, client) + def check_permission(self, token, client, request): + return True + + def query_token(self, token, token_type_hint): + return query_token(token, token_type_hint) def introspect_token(self, token): user = db.query(User).filter(User.id == int(token.user_id)).first() @@ -23,7 +25,7 @@ def introspect_token(self, token): "sub": user.get_user_id(), "aud": token.client_id, "iss": "https://server.example.com/", - "exp": token.get_expires_at(), + "exp": token.issued_at + token.expires_in, "iat": token.issued_at, } @@ -35,131 +37,148 @@ def prepare_data(self): server = create_authorization_server(app) server.register_endpoint(MyIntrospectionEndpoint) - @app.post('/oauth/introspect') - def introspect_token(request: Request, - token: str = Form(None), - token_type_hint: str = Form(None)): + @app.post("/oauth/introspect") + def introspect_token( + request: Request, token: str = Form(None), token_type_hint: str = Form(None) + ): request.body = {} if token: - request.body.update({'token': token}) + request.body.update({"token": token}) if token_type_hint: - request.body.update({'token_type_hint': token_type_hint}) + request.body.update({"token_type_hint": token_type_hint}) - return server.create_endpoint_response('introspection', request=request) + return server.create_endpoint_response("introspection", request=request) - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() client = Client( user_id=user.id, - client_id='introspect-client', - client_secret='introspect-secret', + client_id="introspect-client", + client_secret="introspect-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["http://a.b/c"], + } ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://a.b/c'], - }) db.add(client) db.commit() def create_token(self): token = Token( user_id=1, - client_id='introspect-client', - token_type='bearer', - access_token='a1', - refresh_token='r1', - scope='profile', + client_id="introspect-client", + token_type="bearer", + access_token="a1", + refresh_token="r1", + scope="profile", expires_in=3600, + revoked=False, ) db.add(token) db.commit() def test_invalid_client(self): self.prepare_data() - rv = self.client.post('/oauth/introspect') + rv = self.client.post("/oauth/introspect") resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") - headers = {'Authorization': 'invalid token_string'} - rv = self.client.post('/oauth/introspect', headers=headers) + headers = {"Authorization": "invalid token_string"} + rv = self.client.post("/oauth/introspect", headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") - headers = self.create_basic_header( - 'invalid-client', 'introspect-secret' - ) - rv = self.client.post('/oauth/introspect', headers=headers) + headers = self.create_basic_header("invalid-client", "introspect-secret") + rv = self.client.post("/oauth/introspect", headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") - headers = self.create_basic_header( - 'introspect-client', 'invalid-secret' - ) - rv = self.client.post('/oauth/introspect', headers=headers) + headers = self.create_basic_header("introspect-client", "invalid-secret") + rv = self.client.post("/oauth/introspect", headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") def test_invalid_token(self): self.prepare_data() - headers = self.create_basic_header( - 'introspect-client', 'introspect-secret' - ) - rv = self.client.post('/oauth/introspect', headers=headers) + headers = self.create_basic_header("introspect-client", "introspect-secret") + rv = self.client.post("/oauth/introspect", headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/introspect', data={ - 'token_type_hint': 'refresh_token', - }, headers=headers) + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/introspect", + data={ + "token_type_hint": "refresh_token", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - 'token_type_hint': 'unsupported_token_type', - }, headers=headers) + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "unsupported_token_type", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'unsupported_token_type') - - rv = self.client.post('/oauth/introspect', data={ - 'token': 'invalid-token', - }, headers=headers) + self.assertEqual(resp["error"], "unsupported_token_type") + + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "invalid-token", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['active'], False) - - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - 'token_type_hint': 'refresh_token', - }, headers=headers) + self.assertEqual(resp["active"], False) + + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "refresh_token", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['active'], False) + self.assertEqual(resp["active"], False) def test_introspect_token_with_hint(self): self.prepare_data() self.create_token() - headers = self.create_basic_header( - 'introspect-client', 'introspect-secret' + headers = self.create_basic_header("introspect-client", "introspect-secret") + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + "token_type_hint": "access_token", + }, + headers=headers, ) - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - 'token_type_hint': 'access_token', - }, headers=headers) self.assertEqual(rv.status_code, 200) resp = rv.json() - self.assertEqual(resp['client_id'], 'introspect-client') + self.assertEqual(resp["client_id"], "introspect-client") def test_introspect_token_without_hint(self): self.prepare_data() self.create_token() - headers = self.create_basic_header( - 'introspect-client', 'introspect-secret' + headers = self.create_basic_header("introspect-client", "introspect-secret") + rv = self.client.post( + "/oauth/introspect", + data={ + "token": "a1", + }, + headers=headers, ) - rv = self.client.post('/oauth/introspect', data={ - 'token': 'a1', - }, headers=headers) self.assertEqual(rv.status_code, 200) resp = rv.json() - self.assertEqual(resp['client_id'], 'introspect-client') + self.assertEqual(resp["client_id"], "introspect-client") diff --git a/tests/fastapi/test_oauth2/test_oauth2_server.py b/tests/fastapi/test_oauth2/test_oauth2_server.py index 99fd4b9d..e1b54534 100644 --- a/tests/fastapi/test_oauth2/test_oauth2_server.py +++ b/tests/fastapi/test_oauth2/test_oauth2_server.py @@ -1,9 +1,9 @@ -from fastapi import Request from authlib.integrations.fastapi_oauth2 import ResourceProtector from authlib.integrations.sqla_oauth2 import create_bearer_token_validator -from .models import db, User, Client, Token -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from fastapi import Request + +from .models import Client, Token, User, db +from .oauth2_server import TestCase, create_authorization_server require_oauth = ResourceProtector() BearerTokenValidator = create_bearer_token_validator(db, Token) @@ -11,173 +11,168 @@ def create_resource_server(app): - @app.get('/user') + @app.get("/user") + @require_oauth(["profile"]) def user_profile(request: Request): - with require_oauth.acquire(request, 'profile') as token: - user = token.user - return {'id': user.id, 'username': user.username} + user = request.state.token.user + return {"id": user.id, "username": user.username} - @app.get('/user/email') + @app.get("/user/email") + @require_oauth("email") def user_email(request: Request): - with require_oauth.acquire(request, 'email') as token: - user = token.user - return {'email': user.username + '@example.com'} + pass - @app.get('/info') + @app.get("/info") + @require_oauth() def public_info(request: Request): - with require_oauth.acquire(request) as token: - return {'status': 'ok'} + return {"status": "ok"} - @app.get('/operator-and') + @app.get("/operator-and") + @require_oauth(["profile email"]) def operator_and(request: Request): - with require_oauth.acquire(request, 'profile email', 'AND') as token: - return {'status': 'ok'} + return {"status": "ok"} - @app.get('/operator-or') + @app.get("/operator-or") + @require_oauth(["profile", "email"]) def operator_or(request: Request): - with require_oauth.acquire(request, 'profile email', 'OR') as token: - return {'status': 'ok'} - - def scope_operator(token_scopes, resource_scopes): - return 'profile' in token_scopes and 'email' not in token_scopes - - @app.get('/operator-func') - def operator_func(request: Request): - with require_oauth.acquire(request, operator=scope_operator) as token: - return {'status': 'ok'} + return {"status": "ok"} - @app.get('/acquire') + @app.get("/acquire") def test_acquire(request: Request): - with require_oauth.acquire(request, 'profile') as token: + with require_oauth.acquire(request, ["profile"]) as token: user = token.user - return {'id': user.id, 'username': user.username} + return {"id": user.id, "username": user.username} class AuthorizationTest(TestCase): def test_none_grant(self): create_authorization_server(self.app) authorize_url = ( - '/oauth/authorize?response_type=token' - '&client_id=implicit-client' + "/oauth/authorize?response_type=token" "&client_id=implicit-client" ) rv = self.client.get(authorize_url) - self.assertIn('invalid_grant', rv.json()) + self.assertIn("unsupported_response_type", rv.text) - rv = self.client.post(authorize_url, data={'user_id': '1'}) + rv = self.client.post(authorize_url, data={"user_id": "1"}) self.assertNotEqual(rv.status_code, 200) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'code': 'x', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "code": "x", + }, + ) data = rv.json() - self.assertEqual(data['error'], 'unsupported_grant_type') + self.assertEqual(data["error"], "unsupported_grant_type") class ResourceTest(TestCase): def prepare_data(self): create_resource_server(self.app) - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() client = Client( user_id=user.id, - client_id='resource-client', - client_secret='resource-secret', + client_id="resource-client", + client_secret="resource-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "redirect_uris": ["http://localhost/authorized"], + } ) - client.set_client_metadata({ - 'scope': 'profile', - 'redirect_uris': ['http://localhost/authorized'], - }) db.add(client) db.commit() def create_token(self, expires_in=3600): token = Token( user_id=1, - client_id='resource-client', - token_type='bearer', - access_token='a1', - scope='profile', + client_id="resource-client", + token_type="bearer", + access_token="a1", + scope="profile", expires_in=expires_in, ) db.add(token) db.commit() def create_bearer_header(self, token): - return {'Authorization': 'Bearer ' + token} + return {"Authorization": "Bearer " + token} def test_invalid_token(self): self.prepare_data() - rv = self.client.get('/user') + rv = self.client.get("/user") self.assertEqual(rv.status_code, 401) resp = rv.json() - self.assertEqual(resp['detail']['error'], 'missing_authorization') + self.assertEqual(resp["detail"]["error"], "missing_authorization") - headers = {'Authorization': 'invalid token'} - rv = self.client.get('/user', headers=headers) + headers = {"Authorization": "invalid token"} + rv = self.client.get("/user", headers=headers) self.assertEqual(rv.status_code, 401) resp = rv.json() - self.assertEqual(resp['detail']['error'], 'unsupported_token_type') + self.assertEqual(resp["detail"]["error"], "unsupported_token_type") - headers = self.create_bearer_header('invalid') - rv = self.client.get('/user', headers=headers) + headers = self.create_bearer_header("invalid") + rv = self.client.get("/user", headers=headers) self.assertEqual(rv.status_code, 401) resp = rv.json() - self.assertEqual(resp['detail']['error'], 'invalid_token') + self.assertEqual(resp["detail"]["error"], "invalid_token") def test_expired_token(self): self.prepare_data() - self.create_token(0) - headers = self.create_bearer_header('a1') + self.create_token(-10) + headers = self.create_bearer_header("a1") - rv = self.client.get('/user', headers=headers) + rv = self.client.get("/user", headers=headers) self.assertEqual(rv.status_code, 401) resp = rv.json() - self.assertEqual(resp['detail']['error'], 'invalid_token') + self.assertEqual(resp["detail"]["error"], "invalid_token") - rv = self.client.get('/acquire', headers=headers) + rv = self.client.get("/acquire", headers=headers) self.assertEqual(rv.status_code, 401) def test_insufficient_token(self): self.prepare_data() self.create_token() - headers = self.create_bearer_header('a1') - rv = self.client.get('/user/email', headers=headers) + headers = self.create_bearer_header("a1") + rv = self.client.get("/user/email", headers=headers) self.assertEqual(rv.status_code, 403) resp = rv.json() - self.assertEqual(resp['detail']['error'], 'insufficient_scope') + self.assertEqual(resp["detail"]["error"], "insufficient_scope") def test_access_resource(self): self.prepare_data() self.create_token() - headers = self.create_bearer_header('a1') + headers = self.create_bearer_header("a1") - rv = self.client.get('/user', headers=headers) + rv = self.client.get("/user", headers=headers) resp = rv.json() - self.assertEqual(resp['username'], 'foo') + self.assertEqual(rv.status_code, 200) + self.assertEqual(resp["username"], "foo") - rv = self.client.get('/acquire', headers=headers) + rv = self.client.get("/acquire", headers=headers) resp = rv.json() - self.assertEqual(resp['username'], 'foo') + self.assertEqual(rv.status_code, 200) + self.assertEqual(resp["username"], "foo") - rv = self.client.get('/info', headers=headers) + rv = self.client.get("/info", headers=headers) resp = rv.json() - self.assertEqual(resp['status'], 'ok') + self.assertEqual(rv.status_code, 200) + self.assertEqual(resp["status"], "ok") def test_scope_operator(self): self.prepare_data() self.create_token() - headers = self.create_bearer_header('a1') - rv = self.client.get('/operator-and', headers=headers) + headers = self.create_bearer_header("a1") + rv = self.client.get("/operator-and", headers=headers) self.assertEqual(rv.status_code, 403) resp = rv.json() - self.assertEqual(resp['detail']['error'], 'insufficient_scope') - - rv = self.client.get('/operator-or', headers=headers) - self.assertEqual(rv.status_code, 200) + self.assertEqual(resp["detail"]["error"], "insufficient_scope") - rv = self.client.get('/operator-func', headers=headers) + rv = self.client.get("/operator-or", headers=headers) self.assertEqual(rv.status_code, 200) diff --git a/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py index e6bdbb60..3050504d 100644 --- a/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py +++ b/tests/fastapi/test_oauth2/test_openid_hybrid_grant.py @@ -1,19 +1,16 @@ -from authlib.common.urls import urlparse, url_decode -from authlib.jose import JWT +from authlib.common.urls import url_decode, urlparse +from authlib.jose import jwt +from authlib.oauth2.rfc6749.grants import \ + AuthorizationCodeGrant as _AuthorizationCodeGrant from authlib.oidc.core import HybridIDToken -from authlib.oidc.core.grants import ( - OpenIDCode as _OpenIDCode, - OpenIDHybridGrant as _OpenIDHybridGrant, -) -from authlib.oauth2.rfc6749.grants import ( - AuthorizationCodeGrant as _AuthorizationCodeGrant, -) -from .models import db, User, Client, exists_nonce -from .models import CodeGrantMixin, save_authorization_code -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.oidc.core.grants import OpenIDCode as _OpenIDCode +from authlib.oidc.core.grants import OpenIDHybridGrant as _OpenIDHybridGrant -JWT_CONFIG = {'iss': 'Authlib', 'key': 'secret', 'alg': 'HS256', 'exp': 3600} +from .models import (Client, CodeGrantMixin, User, db, exists_nonce, + save_authorization_code) +from .oauth2_server import TestCase, create_authorization_server + +JWT_CONFIG = {"iss": "Authlib", "key": "secret", "alg": "HS256", "exp": 3600} class AuthorizationCodeGrant(CodeGrantMixin, _AuthorizationCodeGrant): @@ -52,233 +49,281 @@ def prepare_data(self): server.register_grant(OpenIDHybridGrant) server.register_grant(AuthorizationCodeGrant, [OpenIDCode()]) - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() client = Client( user_id=user.id, - client_id='hybrid-client', - client_secret='hybrid-secret', + client_id="hybrid-client", + client_secret="hybrid-secret", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://a.b"], + "scope": "openid profile address", + "response_types": [ + "code id_token", + "code token", + "code id_token token", + ], + "grant_types": ["authorization_code"], + } ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b'], - 'scope': 'openid profile address', - 'response_types': ['code id_token', 'code token', 'code id_token token'], - 'grant_types': ['authorization_code'], - }) db.add(client) db.commit() def validate_claims(self, id_token, params): - jwt = JWT() claims = jwt.decode( - id_token, 'secret', - claims_cls=HybridIDToken, - claims_params=params + id_token, "secret", claims_cls=HybridIDToken, claims_params=params ) claims.validate() def test_invalid_client_id(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') - - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'invalid-client', - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) + self.assertEqual(resp["error"], "invalid_client") + + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "invalid-client", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") def test_require_nonce(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code token', - 'scope': 'openid profile', - 'state': 'bar', - 'redirect_uri': 'https://a.b', - 'user_id': '1' - }) - self.assertIn('error=invalid_request', rv.headers['location']) - self.assertIn('nonce', rv.headers['location']) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code token", + "scope": "openid profile", + "state": "bar", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_request", rv.headers["location"]) + self.assertIn("nonce", rv.headers["location"]) def test_invalid_response_type(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token invalid', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token invalid", + "state": "bar", + "nonce": "abc", + "scope": "profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_grant') + self.assertEqual(resp["error"], "unsupported_response_type") def test_invalid_scope(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('error=invalid_scope', rv.headers['location']) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token", + "state": "bar", + "nonce": "abc", + "scope": "profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_scope", rv.headers["location"]) def test_access_denied(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - }) - self.assertIn('error=access_denied', rv.headers['location']) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + }, + ) + self.assertIn("error=access_denied", rv.headers["location"]) def test_code_access_token(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.headers['location']) - self.assertIn('access_token=', rv.headers['location']) - self.assertNotIn('id_token=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) - self.assertEqual(params['state'], 'bar') - - code = params['code'] - headers = self.create_basic_header('hybrid-client', 'hybrid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("access_token=", rv.headers["location"]) + self.assertNotIn("id_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.assertEqual(params["state"], "bar") + + code = params["code"] + headers = self.create_basic_header("hybrid-client", "hybrid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://a.b", + "code": code, + }, + headers=headers, + ) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) + self.assertIn("access_token", resp) + self.assertIn("id_token", resp) def test_code_id_token(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.headers['location']) - self.assertIn('id_token=', rv.headers['location']) - self.assertNotIn('access_token=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) - self.assertEqual(params['state'], 'bar') - - params['nonce'] = 'abc' - params['client_id'] = 'hybrid-client' - self.validate_claims(params['id_token'], params) - - code = params['code'] - headers = self.create_basic_header('hybrid-client', 'hybrid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertNotIn("access_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.assertEqual(params["state"], "bar") + + params["nonce"] = "abc" + params["client_id"] = "hybrid-client" + self.validate_claims(params["id_token"], params) + + code = params["code"] + headers = self.create_basic_header("hybrid-client", "hybrid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://a.b", + "code": code, + }, + headers=headers, + ) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) + self.assertIn("access_token", resp) + self.assertIn("id_token", resp) def test_code_id_token_access_token(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token token', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.headers['location']) - self.assertIn('id_token=', rv.headers['location']) - self.assertIn('access_token=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) - self.assertEqual(params['state'], 'bar') - self.validate_claims(params['id_token'], params) - - code = params['code'] - headers = self.create_basic_header('hybrid-client', 'hybrid-secret') - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'authorization_code', - 'redirect_uri': 'https://a.b', - 'code': code, - }, headers=headers) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token token", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("access_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.assertEqual(params["state"], "bar") + self.validate_claims(params["id_token"], params) + + code = params["code"] + headers = self.create_basic_header("hybrid-client", "hybrid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "authorization_code", + "redirect_uri": "https://a.b", + "code": code, + }, + headers=headers, + ) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('id_token', resp) + self.assertIn("access_token", resp) + self.assertIn("id_token", resp) def test_response_mode_query(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token token', - 'response_mode': 'query', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) - self.assertIn('code=', rv.headers['location']) - self.assertIn('id_token=', rv.headers['location']) - self.assertIn('access_token=', rv.headers['location']) - - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) - self.assertEqual(params['state'], 'bar') + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token token", + "response_mode": "query", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) + self.assertIn("code=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("access_token=", rv.headers["location"]) + + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.assertEqual(params["state"], "bar") def test_response_mode_form_post(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'client_id': 'hybrid-client', - 'response_type': 'code id_token token', - 'response_mode': 'form_post', - 'state': 'bar', - 'nonce': 'abc', - 'scope': 'openid profile', - 'redirect_uri': 'https://a.b', - 'user_id': '1', - }) + rv = self.client.post( + "/oauth/authorize", + data={ + "client_id": "hybrid-client", + "response_type": "code id_token token", + "response_mode": "form_post", + "state": "bar", + "nonce": "abc", + "scope": "openid profile", + "redirect_uri": "https://a.b", + "user_id": "1", + }, + ) resp = rv.json() self.assertIn('name="code"', resp) self.assertIn('name="id_token"', resp) diff --git a/tests/fastapi/test_oauth2/test_openid_implict_grant.py b/tests/fastapi/test_oauth2/test_openid_implict_grant.py index b6e0a953..67b2fbbd 100644 --- a/tests/fastapi/test_oauth2/test_openid_implict_grant.py +++ b/tests/fastapi/test_oauth2/test_openid_implict_grant.py @@ -1,17 +1,16 @@ -from authlib.jose import JWT +from authlib.common.urls import add_params_to_uri, url_decode, urlparse +from authlib.jose import jwt from authlib.oidc.core import ImplicitIDToken -from authlib.oidc.core.grants import ( +from authlib.oidc.core.grants import \ OpenIDImplicitGrant as _OpenIDImplicitGrant -) -from authlib.common.urls import urlparse, url_decode, add_params_to_uri -from .models import db, User, Client, exists_nonce -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server + +from .models import Client, User, db, exists_nonce +from .oauth2_server import TestCase, create_authorization_server class OpenIDImplicitGrant(_OpenIDImplicitGrant): def get_jwt_config(self): - return dict(key='secret', alg='HS256', iss='Authlib', exp=3600) + return dict(key="secret", alg="HS256", iss="Authlib", exp=3600) def generate_user_info(self, user, scopes): return user.generate_user_info(scopes) @@ -25,148 +24,172 @@ def prepare_data(self): server = create_authorization_server(self.app) server.register_grant(OpenIDImplicitGrant) - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() client = Client( user_id=user.id, - client_id='implicit-client', - client_secret='', + client_id="implicit-client", + client_secret="", + ) + client.set_client_metadata( + { + "redirect_uris": ["https://a.b/c"], + "scope": "openid profile", + "token_endpoint_auth_method": "none", + "response_types": ["id_token", "id_token token"], + } ) - client.set_client_metadata({ - 'redirect_uris': ['https://a.b/c'], - 'scope': 'openid profile', - 'token_endpoint_auth_method': 'none', - 'response_types': ['id_token', 'id_token token'], - }) self.authorize_url = ( - '/oauth/authorize?response_type=token' - '&client_id=implicit-client' + "/oauth/authorize?response_type=token" "&client_id=implicit-client" ) db.add(client) db.commit() def validate_claims(self, id_token, params): - jwt = JWT(['HS256']) claims = jwt.decode( - id_token, 'secret', - claims_cls=ImplicitIDToken, - claims_params=params + id_token, "secret", claims_cls=ImplicitIDToken, claims_params=params ) claims.validate() def test_consent_view(self): self.prepare_data() - rv = self.client.get(add_params_to_uri('/oauth/authorize', { - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'foo', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - })) - self.assertIn('error=invalid_request', rv.json()) - self.assertIn('nonce', rv.json()) + rv = self.client.get( + add_params_to_uri( + "/oauth/authorize", + { + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "foo", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + ) + self.assertIn("error=invalid_request", rv.json()) + self.assertIn("nonce", rv.json()) def test_require_nonce(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('error=invalid_request', rv.headers['location']) - self.assertIn('nonce', rv.headers['location']) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_request", rv.headers["location"]) + self.assertIn("nonce", rv.headers["location"]) def test_missing_openid_in_scope(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token token', - 'client_id': 'implicit-client', - 'scope': 'profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('error=invalid_scope', rv.headers['location']) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token token", + "client_id": "implicit-client", + "scope": "profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("error=invalid_scope", rv.headers["location"]) def test_denied(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - }) - self.assertIn('error=access_denied', rv.headers['location']) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + }, + ) + self.assertIn("error=access_denied", rv.headers["location"]) def test_authorize_access_token(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('access_token=', rv.headers['location']) - self.assertIn('id_token=', rv.headers['location']) - self.assertIn('state=bar', rv.headers['location']) - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) - self.validate_claims(params['id_token'], params) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("access_token=", rv.headers["location"]) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.validate_claims(params["id_token"], params) def test_authorize_id_token(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('id_token=', rv.headers['location']) - self.assertIn('state=bar', rv.headers['location']) - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).fragment)) - self.validate_claims(params['id_token'], params) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).fragment)) + self.validate_claims(params["id_token"], params) def test_response_mode_query(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'response_mode': 'query', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) - self.assertIn('id_token=', rv.headers['location']) - self.assertIn('state=bar', rv.headers['location']) - params = dict(url_decode(urlparse.urlparse(rv.headers['location']).query)) - self.validate_claims(params['id_token'], params) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "response_mode": "query", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) + self.assertIn("id_token=", rv.headers["location"]) + self.assertIn("state=bar", rv.headers["location"]) + params = dict(url_decode(urlparse.urlparse(rv.headers["location"]).query)) + self.validate_claims(params["id_token"], params) def test_response_mode_form_post(self): self.prepare_data() - rv = self.client.post('/oauth/authorize', data={ - 'response_type': 'id_token', - 'response_mode': 'form_post', - 'client_id': 'implicit-client', - 'scope': 'openid profile', - 'state': 'bar', - 'nonce': 'abc', - 'redirect_uri': 'https://a.b/c', - 'user_id': '1' - }) + rv = self.client.post( + "/oauth/authorize", + data={ + "response_type": "id_token", + "response_mode": "form_post", + "client_id": "implicit-client", + "scope": "openid profile", + "state": "bar", + "nonce": "abc", + "redirect_uri": "https://a.b/c", + "user_id": "1", + }, + ) self.assertIn('name="id_token"', rv.json()) self.assertIn('name="state"', rv.json()) diff --git a/tests/fastapi/test_oauth2/test_password_grant.py b/tests/fastapi/test_oauth2/test_password_grant.py index a7fd50f9..92c467a5 100644 --- a/tests/fastapi/test_oauth2/test_password_grant.py +++ b/tests/fastapi/test_oauth2/test_password_grant.py @@ -1,10 +1,9 @@ from authlib.common.urls import add_params_to_uri -from authlib.oauth2.rfc6749.grants import ( - ResourceOwnerPasswordCredentialsGrant as _PasswordGrant, -) -from .models import db, User, Client -from .oauth2_server import TestCase -from .oauth2_server import create_authorization_server +from authlib.oauth2.rfc6749.grants import \ + ResourceOwnerPasswordCredentialsGrant as _PasswordGrant + +from .models import Client, User, db +from .oauth2_server import TestCase, create_authorization_server class PasswordGrant(_PasswordGrant): @@ -15,151 +14,182 @@ def authenticate_user(self, username, password): class PasswordTest(TestCase): - def prepare_data(self, grant_type='password'): + def prepare_data(self, grant_type="password"): server = create_authorization_server(self.app) server.register_grant(PasswordGrant) self.server = server - user = User(username='foo') + user = User(username="foo") db.add(user) db.commit() client = Client( user_id=user.id, - client_id='password-client', - client_secret='password-secret', + client_id="password-client", + client_secret="password-secret", + ) + client.set_client_metadata( + { + "scope": "profile", + "grant_types": [grant_type], + "redirect_uris": ["http://localhost/authorized"], + } ) - client.set_client_metadata({ - 'scope': 'profile', - 'grant_types': [grant_type], - 'redirect_uris': ['http://localhost/authorized'], - }) db.add(client) db.commit() def test_invalid_client(self): self.prepare_data() - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }) + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') - - headers = self.create_basic_header( - 'password-client', 'invalid-secret' + self.assertEqual(resp["error"], "invalid_client") + + headers = self.create_basic_header("password-client", "invalid-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_client') + self.assertEqual(resp["error"], "invalid_client") def test_invalid_scope(self): self.prepare_data() - self.server.metadata = {'scopes_supported': ['profile']} - headers = self.create_basic_header( - 'password-client', 'password-secret' + self.server.scopes_supported = "profile" + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + "scope": "invalid", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - 'scope': 'invalid', - }, headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_scope') + self.assertEqual(resp["error"], "invalid_scope") def test_invalid_request(self): self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' + headers = self.create_basic_header("password-client", "password-secret") + + rv = self.client.get( + add_params_to_uri( + "/oauth/token", + { + "grant_type": "password", + }, + ), + headers=headers, ) - - rv = self.client.get(add_params_to_uri('/oauth/token', { - 'grant_type': 'password', - }), headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'unsupported_grant_type') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - }, headers=headers) + self.assertEqual(resp["error"], "unsupported_grant_type") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - }, headers=headers) + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') - - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'wrong', - }, headers=headers) + self.assertEqual(resp["error"], "invalid_request") + + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "wrong", + }, + headers=headers, + ) resp = rv.json() - self.assertEqual(resp['error'], 'invalid_request') + self.assertEqual(resp["error"], "invalid_request") def test_invalid_grant_type(self): - self.prepare_data(grant_type='invalid') - headers = self.create_basic_header( - 'password-client', 'password-secret' + self.prepare_data(grant_type="invalid") + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) resp = rv.json() - self.assertEqual(resp['error'], 'unauthorized_client') + self.assertEqual(resp["error"], "unauthorized_client") def test_authorize_token(self): self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) resp = rv.json() - self.assertIn('access_token', resp) + self.assertIn("access_token", resp) def test_token_generator(self): - m = 'tests.fastapi.test_oauth2.oauth2_server:token_generator' - self.app.config.update({'OAUTH2_ACCESS_TOKEN_GENERATOR': m}) + m = "tests.fastapi.test_oauth2.oauth2_server:token_generator" + self.app.config.update({"OAUTH2_ACCESS_TOKEN_GENERATOR": m}) self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) resp = rv.json() - self.assertIn('access_token', resp) - self.assertIn('p-password.1.', resp['access_token']) + self.assertIn("access_token", resp) + self.assertIn("p-password.1.", resp["access_token"]) def test_custom_expires_in(self): - self.app.config.update({ - 'OAUTH2_TOKEN_EXPIRES_IN': {'password': 1800} - }) + self.app.config.update({"OAUTH2_TOKEN_EXPIRES_IN": {"password": 1800}}) self.prepare_data() - headers = self.create_basic_header( - 'password-client', 'password-secret' + headers = self.create_basic_header("password-client", "password-secret") + rv = self.client.post( + "/oauth/token", + data={ + "grant_type": "password", + "username": "foo", + "password": "ok", + }, + headers=headers, ) - rv = self.client.post('/oauth/token', data={ - 'grant_type': 'password', - 'username': 'foo', - 'password': 'ok', - }, headers=headers) resp = rv.json() - self.assertIn('access_token', resp) - self.assertEqual(resp['expires_in'], 1800) + self.assertIn("access_token", resp) + self.assertEqual(resp["expires_in"], 1800) diff --git a/tox.ini b/tox.ini index fb545f5e..2ae88c74 100644 --- a/tox.ini +++ b/tox.ini @@ -9,7 +9,6 @@ deps = -rrequirements-test.txt fastapi: FastAPI fastapi: sqlalchemy - fastapi: mangum fastapi: werkzeug fastapi: python-multipart flask: Flask