From 34781a6da5499102f0e6e836d41f048e393771e0 Mon Sep 17 00:00:00 2001 From: Mikel Alejo Date: Wed, 14 Feb 2024 14:50:27 +0100 Subject: [PATCH] RHCLOUD-30838 | refactor: move the token validation to the middleware (#1021) * refactor: move the token validation to the middleware Instead of validating the token that comes in the requests made by the service accounts in every view, it makes more sense to validate it when we are building our internal user object, and to leave it there for easy access for the rest of the code. RHCLOUD-30838 * refactor: apply review suggestions RHCLOUD-30838 * refactor: exception handling closer to the function that raises it RHCLOUD-30838 --- rbac/api/models.py | 1 + rbac/management/group/view.py | 101 +----------- rbac/management/principal/it_service.py | 8 +- rbac/management/principal/view.py | 50 +----- rbac/rbac/middleware.py | 55 +++++++ tests/identity_request.py | 43 ++++- tests/management/principal/test_view.py | 13 -- tests/rbac/test_middleware.py | 198 ++++++++++++++++++++++++ 8 files changed, 300 insertions(+), 169 deletions(-) diff --git a/rbac/api/models.py b/rbac/api/models.py index 733dd1617..95678a660 100644 --- a/rbac/api/models.py +++ b/rbac/api/models.py @@ -72,5 +72,6 @@ class User: is_active = True org_id = None # Service account properties. + bearer_token: str = "" client_id: str = "" is_service_account: bool = False diff --git a/rbac/management/group/view.py b/rbac/management/group/view.py index 5e33d1139..8c872cfa5 100644 --- a/rbac/management/group/view.py +++ b/rbac/management/group/view.py @@ -53,13 +53,12 @@ from rest_framework import mixins, serializers, status, viewsets from rest_framework.decorators import action from rest_framework.filters import OrderingFilter +from rest_framework.request import Request from rest_framework.response import Response from api.models import Tenant, User from .insufficient_privileges import InsufficientPrivilegesError from .service_account_not_found_error import ServiceAccountNotFoundError -from ..authorization.token_validator import ITSSOTokenValidator, InvalidTokenError, MissingAuthorizationError -from ..authorization.token_validator import UnableMeetPrerequisitesError from ..principal.unexpected_status_code_from_it import UnexpectedStatusCodeFromITError USERNAMES_KEY = "usernames" @@ -397,7 +396,6 @@ def add_service_accounts( self, user: User, group: Group, - bearer_token: str, service_accounts: Iterable[dict], account_name: str = "", org_id: str = "", @@ -407,7 +405,7 @@ def add_service_accounts( # want to skip calling IT it_service = ITService() if not settings.IT_BYPASS_IT_CALLS: - it_service_accounts = it_service.request_service_accounts(bearer_token=bearer_token) + it_service_accounts = it_service.request_service_accounts(bearer_token=user.bearer_token) # Organize them by their client ID. it_service_accounts_by_client_ids: dict[str, dict] = {} @@ -502,7 +500,7 @@ def remove_principals(self, group, principals, account=None, org_id=None): return group @action(detail=True, methods=["get", "post", "delete"]) - def principals(self, request, uuid=None): + def principals(self, request: Request, uuid=None): """Get, add or remove principals from a group.""" """ @api {get} /api/v1/groups/:uuid/principals/ Get principals for a group @@ -612,56 +610,11 @@ def principals(self, request, uuid=None): # Process the service accounts and add them to the group. if len(service_accounts) > 0: - try: - # Attempt validating the JWT token. - token_validator = ITSSOTokenValidator() - bearer_token = token_validator.validate_token(request=request) - except MissingAuthorizationError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "The authorization header is required for fetching service accounts.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except InvalidTokenError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "Invalid token provided.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except UnableMeetPrerequisitesError: - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={ - "errors": [ - { - "detail": "Unable to validate token.", - "source": "groups", - "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), - } - ] - }, - ) - try: resp = self.add_service_accounts( user=request.user, group=group, service_accounts=service_accounts, - bearer_token=bearer_token, account_name=account, org_id=org_id, ) @@ -727,61 +680,19 @@ def principals(self, request, uuid=None): # Make sure we return early for service accounts. if principalType == "service-account": - try: - # Attempt validating the JWT token. - token_validator = ITSSOTokenValidator() - bearer_token = token_validator.validate_token(request=request) - except MissingAuthorizationError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "The authorization header is required for fetching service accounts.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except InvalidTokenError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "Invalid token provided.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except UnableMeetPrerequisitesError: - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={ - "errors": [ - { - "detail": "Unable to validate token.", - "source": "groups", - "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), - } - ] - }, - ) - # Get the service account's description and name filters, and the principal's username filter too. # Finally, get the limit and offset parameters. options[SERVICE_ACCOUNT_DESCRIPTION_KEY] = request.query_params.get(SERVICE_ACCOUNT_DESCRIPTION_KEY) options[SERVICE_ACCOUNT_NAME_KEY] = request.query_params.get(SERVICE_ACCOUNT_NAME_KEY) + + # Get the "principal username" parameter. options[PRINCIPAL_USERNAME_KEY] = request.query_params.get(PRINCIPAL_USERNAME_KEY) # Fetch the group's service accounts. it_service = ITService() try: service_accounts = it_service.get_service_accounts_group( - group=group, bearer_token=bearer_token, options=options + group=group, user=request.user, options=options ) except (requests.exceptions.ConnectionError, UnexpectedStatusCodeFromITError): return Response( diff --git a/rbac/management/principal/it_service.py b/rbac/management/principal/it_service.py index 7ec79b01c..aae4e0b8e 100644 --- a/rbac/management/principal/it_service.py +++ b/rbac/management/principal/it_service.py @@ -155,12 +155,12 @@ def request_service_accounts(self, bearer_token: str) -> list[dict]: return service_accounts - def get_service_accounts(self, user: User, bearer_token: str, options: dict = {}) -> Tuple[list[dict], int]: + def get_service_accounts(self, user: User, options: dict = {}) -> Tuple[list[dict], int]: """Request and returns the service accounts for the given tenant.""" # We might want to bypass calls to the IT service on ephemeral or test environments. it_service_accounts: list[dict] = [] if not settings.IT_BYPASS_IT_CALLS: - it_service_accounts = self.request_service_accounts(bearer_token=bearer_token) + it_service_accounts = self.request_service_accounts(bearer_token=user.bearer_token) # Get the service accounts from the database. The weird filter is to fetch the service accounts depending on # the account number or the organization ID the user gave. @@ -238,12 +238,12 @@ def get_service_accounts(self, user: User, bearer_token: str, options: dict = {} return service_accounts, count - def get_service_accounts_group(self, group: Group, bearer_token: str, options: dict = {}) -> list[dict]: + def get_service_accounts_group(self, group: Group, user: User, options: dict = {}) -> list[dict]: """Get the service accounts for the given group.""" # We might want to bypass calls to the IT service on ephemeral or test environments. it_service_accounts: list[dict] = [] if not settings.IT_BYPASS_IT_CALLS: - it_service_accounts = self.request_service_accounts(bearer_token=bearer_token) + it_service_accounts = self.request_service_accounts(bearer_token=user.bearer_token) # Fetch the service accounts from the group. group_service_account_principals = group.principals.filter(type=TYPE_SERVICE_ACCOUNT) diff --git a/rbac/management/principal/view.py b/rbac/management/principal/view.py index f8d58e645..fc57ca157 100644 --- a/rbac/management/principal/view.py +++ b/rbac/management/principal/view.py @@ -26,8 +26,6 @@ from .it_service import ITService from .proxy import PrincipalProxy from .unexpected_status_code_from_it import UnexpectedStatusCodeFromITError -from ..authorization.token_validator import ITSSOTokenValidator, InvalidTokenError, MissingAuthorizationError -from ..authorization.token_validator import UnableMeetPrerequisitesError from ..permissions.principal_access import PrincipalAccessPermission USERNAMES_KEY = "usernames" @@ -133,50 +131,6 @@ def get(self, request): # Get either service accounts or user principals, depending on what the user specified. if principal_type == "service-account": - try: - # Attempt validating the JWT token. - token_validator = ITSSOTokenValidator() - bearer_token = token_validator.validate_token(request=request) - except MissingAuthorizationError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "The authorization header is required for fetching service accounts.", - "source": "principals", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except InvalidTokenError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "Invalid token provided.", - "source": "principals", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except UnableMeetPrerequisitesError: - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={ - "errors": [ - { - "detail": "Unable to validate token.", - "source": "principals", - "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), - } - ] - }, - ) - options["email"] = query_params.get(EMAIL_KEY) options["match_criteria"] = validate_and_get_key( query_params, MATCH_CRITERIA_KEY, VALID_MATCH_VALUE, required=False @@ -189,9 +143,7 @@ def get(self, request): # Fetch the service accounts from IT. try: it_service = ITService() - service_accounts, sa_count = it_service.get_service_accounts( - user=user, bearer_token=bearer_token, options=options - ) + service_accounts, sa_count = it_service.get_service_accounts(user=user, options=options) except (requests.exceptions.ConnectionError, UnexpectedStatusCodeFromITError): return Response( status=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/rbac/rbac/middleware.py b/rbac/rbac/middleware.py index 2fd53960f..9f2dcbf6b 100644 --- a/rbac/rbac/middleware.py +++ b/rbac/rbac/middleware.py @@ -27,6 +27,8 @@ from django.http import Http404, HttpResponse, QueryDict from django.urls import resolve from django.utils.deprecation import MiddlewareMixin +from management.authorization.token_validator import ITSSOTokenValidator, InvalidTokenError, MissingAuthorizationError +from management.authorization.token_validator import UnableMeetPrerequisitesError from management.cache import TenantCache from management.models import Principal from management.utils import APPLICATION_KEY, access_for_principal, validate_psk @@ -232,6 +234,59 @@ def process_request(self, request): # pylint: disable=R1710 user.user_id = None user.system = False + # The requests made by service accounts are expected to come with an Authorization header which + # contains a Bearer token. Therefore, we will attempt to extract it and validate it, and also store it + # in case we need to use it to contact IT with it. + token_validator = ITSSOTokenValidator() + try: + user.bearer_token = token_validator.validate_token(request=request) + except InvalidTokenError: + return HttpResponse( + content=json.dumps( + { + "errors": [ + { + "detail": "Invalid token provided.", + "status": str(status.HTTP_401_UNAUTHORIZED), + } + ] + } + ), + content_type="application/json", + status=status.HTTP_401_UNAUTHORIZED, + ) + except MissingAuthorizationError: + return HttpResponse( + content=json.dumps( + { + "errors": [ + { + "detail": "A Bearer token in an authorization header is required when" + " contacting RBAC with a service account.", + "status": str(status.HTTP_401_UNAUTHORIZED), + } + ] + } + ), + content_type="application/json", + status=status.HTTP_401_UNAUTHORIZED, + ) + except UnableMeetPrerequisitesError: + return HttpResponse( + content=json.dumps( + { + "errors": [ + { + "detail": "Unable to validate the provided token.", + "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), + } + ] + } + ), + content_type="application/json", + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + # If we did not get the user information or service account information from the "x-rh-identity" header, # then the request is directly unauthorized. if not user_info and not service_account: diff --git a/tests/identity_request.py b/tests/identity_request.py index cf185269a..dc106be1d 100644 --- a/tests/identity_request.py +++ b/tests/identity_request.py @@ -15,6 +15,8 @@ # along with this program. If not, see . # """Test Case extension to collect common test data.""" +import uuid + from base64 import b64encode from json import dumps as json_dumps from unittest.mock import Mock @@ -69,21 +71,29 @@ def _create_user_data(cls): user_data = {"username": cls.fake.user_name(), "email": cls.fake.email()} return user_data + def _create_service_account_data(cls) -> dict[str, str]: + """Create service account data""" + client_id = str(uuid.uuid4()) + return {"client_id": client_id, "username": f"service-account-{client_id}"} + @classmethod def _create_request_context( cls, - customer_data, - user_data, - is_org_admin=True, - is_internal=False, - cross_account=False, + customer_data: dict[str, str], + user_data: dict[str, str], + is_org_admin: bool = True, + is_internal: bool = False, + cross_account: bool = False, + service_account_data: dict[str, str] = None, ): """Create the request context for a user.""" customer = customer_data account = customer.get("account_id") org_id = customer.get("org_id", None) - identity = cls._build_identity(user_data, account, org_id, is_org_admin, is_internal) + identity = cls._build_identity( + user_data, account, org_id, is_org_admin, is_internal, service_account_data=service_account_data + ) if cross_account: identity["identity"]["internal"] = {"cross_access": True} json_identity = json_dumps(identity) @@ -95,7 +105,15 @@ def _create_request_context( return request_context @classmethod - def _build_identity(cls, user_data, account, org_id, is_org_admin, is_internal): + def _build_identity( + cls, + user_data: dict[str, str], + account: str, + org_id: str, + is_org_admin: bool, + is_internal: bool, + service_account_data: dict[str, str] = None, + ): identity = {"identity": {"account_number": account, "org_id": org_id}} if user_data is not None: identity["identity"]["user"] = { @@ -105,11 +123,20 @@ def _build_identity(cls, user_data, account, org_id, is_org_admin, is_internal): "user_id": "1111111", } + if service_account_data: + identity["identity"]["service_account"] = { + "client_id": service_account_data.get("client_id"), + "username": service_account_data.get("username"), + } + if is_internal: identity["identity"]["type"] = "Associate" identity["identity"]["associate"] = identity.get("identity").get("user") identity["identity"]["user"]["is_internal"] = True else: - identity["identity"]["type"] = "User" + if user_data: + identity["identity"]["type"] = "User" + else: + identity["identity"]["type"] = "ServiceAccount" return identity diff --git a/tests/management/principal/test_view.py b/tests/management/principal/test_view.py index e83acd02a..79cb066c0 100644 --- a/tests/management/principal/test_view.py +++ b/tests/management/principal/test_view.py @@ -827,19 +827,6 @@ def test_read_principal_users(self, mock_request): cross_account_principal.delete() - def test_fetch_service_accounts(self): - """Test fetching service accounts while not providing a token - - Test that when the user request the service accounts without an authorization token, an unauthorized response - is returned - """ - - url = f'{reverse("principals")}?type=service-account' - client = APIClient() - response: Response = client.get(url, **self.headers) - - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - @override_settings(IT_BYPASS_TOKEN_VALIDATION=True) @patch("management.principal.it_service.ITService.request_service_accounts") def test_read_principal_service_account_list_success(self, mock_request): diff --git a/tests/rbac/test_middleware.py b/tests/rbac/test_middleware.py index 2290acf46..0c33e6bfe 100644 --- a/tests/rbac/test_middleware.py +++ b/tests/rbac/test_middleware.py @@ -16,7 +16,9 @@ # """Test the project middleware.""" import collections +import json import os +from unittest import mock from unittest.mock import Mock from django.conf import settings from django.http import QueryDict @@ -31,6 +33,7 @@ from api.serializers import create_tenant_name from tests.identity_request import IdentityRequest from rbac.middleware import HttpResponseUnauthorizedRequest, IdentityHeaderMiddleware +from management.authorization.token_validator import UnableMeetPrerequisitesError from management.models import Access, Group, Permission, Principal, Policy, ResourceDefinition, Role @@ -342,6 +345,201 @@ def test_should_load_user_permissions_regular_user_access_missing_query_params(s self.assertEqual(middleware.should_load_user_permissions(request, user), False) + def test_service_account_with_no_authorization_header(self): + """Test that a 401 is returned when a service account with no authorization header contacts RBAC.""" + # Prepare the mocked service account data. + service_account_data = self._create_service_account_data() + customer = self._create_customer_data() + + # Prepare the mocked request. + request_context = self._create_request_context( + customer_data=customer, user_data=None, service_account_data=service_account_data + ) + mock_request = request_context["request"] + mock_request.path = "/api/v1/access/" + middleware = IdentityHeaderMiddleware(get_response=IdentityHeaderMiddleware.process_request) + + # Set the authorization header to an empty one to test that a missing authorization exception is raised. + mock_request.headers = {"Authorization": None} + + # Call the middleware under test. + result = middleware.process_request(mock_request) + + # Assert that the content type header has the correct value. + self.assertEqual( + "application/json", result.headers.get("Content-Type"), "the content type header has the incorrect value" + ) + + # Assert that the status code is the expected one. + self.assertEqual( + status.HTTP_401_UNAUTHORIZED, + result.status_code, + "unexpected status code received when the authorization header is missing", + ) + + # Assert that the contents of the body are the expected ones. + content = json.loads(result.content.decode("utf-8")) + errors = content["errors"] + if not errors: + self.fail('expected an "errors" array in the received response\'s body, but it was not found') + + if len(errors) != 1: + self.fail(f'expected a single error in the "errors" array, {len(errors)} errors received') + + error = errors[0] + error_detail = error.get("detail") + + if not error_detail: + self.fail("the error detail is missing from the error object") + + self.assertEqual( + error_detail, + "A Bearer token in an authorization header is required when contacting RBAC with a service account.", + "unexpected error detail received in the response", + ) + + error_status = error.get("status") + + if not error_status: + self.fail("the error object is missing the status code") + + self.assertEqual( + str(status.HTTP_401_UNAUTHORIZED), + error_status, + "unexpected status code received in the body of the response", + ) + + @mock.patch("management.authorization.token_validator.ITSSOTokenValidator._get_json_web_keyset") + def test_service_account_unable_validate_token(self, _get_json_web_keyset: Mock): + """Test 500 response for a service account request when unable to meet prerequisites to validate the token.""" + # Prepare the mocked service account data. + service_account_data = self._create_service_account_data() + customer = self._create_customer_data() + + # Prepare the mocked request. + request_context = self._create_request_context( + customer_data=customer, user_data=None, service_account_data=service_account_data + ) + mock_request = request_context["request"] + mock_request.path = "/api/v1/access/" + middleware = IdentityHeaderMiddleware(get_response=IdentityHeaderMiddleware.process_request) + + # Set a non-empty authorization header. + mock_request.headers = {"Authorization": "invalid-bearer-token"} + # Pretend that we are not able to contact IT in order to fetch the required key set to validate the token. + _get_json_web_keyset.side_effect = UnableMeetPrerequisitesError() + + # Call the middleware under test. + result = middleware.process_request(mock_request) + + # Assert that the content type header has the correct value. + self.assertEqual( + "application/json", result.headers.get("Content-Type"), "the content type header has the incorrect value" + ) + + # Assert that the status code is the expected one. + self.assertEqual( + status.HTTP_500_INTERNAL_SERVER_ERROR, + result.status_code, + "unexpected status code received when we are unable to meet the prerequisites to validate the token", + ) + + # Assert that the contents of the body are the expected ones. + content = json.loads(result.content.decode("utf-8")) + errors = content["errors"] + if not errors: + self.fail('expected an "errors" array in the received response\'s body, but it was not found') + + if len(errors) != 1: + self.fail(f'expected a single error in the "errors" array, {len(errors)} errors received') + + error = errors[0] + error_detail = error.get("detail") + + if not error_detail: + self.fail("the error detail is missing from the error object") + + self.assertEqual( + error_detail, "Unable to validate the provided token.", "unexpected error detail received in the response" + ) + + error_status = error.get("status") + + if not error_status: + self.fail("the error object is missing the status code") + + self.assertEqual( + str(status.HTTP_500_INTERNAL_SERVER_ERROR), + error_status, + "unexpected status code received in the body of the response", + ) + + @mock.patch("management.authorization.token_validator.ITSSOTokenValidator._get_json_web_keyset") + @mock.patch("management.authorization.token_validator.jwt.decode") + def test_service_account_with_invalid_bearer_token(self, decode: Mock, _get_json_web_keyset: Mock): + """Test 401 response for a service account request with an invalid bearer token.""" + # Prepare the mocked service account data. + service_account_data = self._create_service_account_data() + customer = self._create_customer_data() + + # Prepare the mocked request. + request_context = self._create_request_context( + customer_data=customer, user_data=None, service_account_data=service_account_data + ) + mock_request = request_context["request"] + mock_request.path = "/api/v1/access/" + middleware = IdentityHeaderMiddleware(get_response=IdentityHeaderMiddleware.process_request) + + # Set a non-empty authorization header. + mock_request.headers = {"Authorization": "invalid-bearer-token"} + # The token validator should avoid calling IT for the test. + _get_json_web_keyset.return_value = "invalid-return-value" + # Make the "decode" function raise any exception to test that it gets properly handled. + decode.side_effect = ValueError("invalid value") + + # Call the middleware under test. + result = middleware.process_request(mock_request) + + # Assert that the content type header has the correct value. + self.assertEqual( + "application/json", result.headers.get("Content-Type"), "the content type header has the incorrect value" + ) + + # Assert that the status code is the expected one. + self.assertEqual( + status.HTTP_401_UNAUTHORIZED, + result.status_code, + "unexpected status code received when the bearer token is invalid", + ) + + # Assert that the contents of the body are the expected ones. + content = json.loads(result.content.decode("utf-8")) + errors = content["errors"] + if not errors: + self.fail('expected an "errors" array in the received response\'s body, but it was not found') + + if len(errors) != 1: + self.fail(f'expected a single error in the "errors" array, {len(errors)} errors received') + + error = errors[0] + error_detail = error.get("detail") + + if not error_detail: + self.fail("the error detail is missing from the error object") + + self.assertEqual(error_detail, "Invalid token provided.", "unexpected error detail received in the response") + + error_status = error.get("status") + + if not error_status: + self.fail("the error object is missing the status code") + + self.assertEqual( + str(status.HTTP_401_UNAUTHORIZED), + error_status, + "unexpected status code received in the body of the response", + ) + class ServiceToService(IdentityRequest): """Tests requests without an identity header."""