diff --git a/rbac/api/models.py b/rbac/api/models.py index 733dd161..95678a66 100644 --- a/rbac/api/models.py +++ b/rbac/api/models.py @@ -72,5 +72,6 @@ class User: is_active = True org_id = None # Service account properties. + bearer_token: str = "" client_id: str = "" is_service_account: bool = False diff --git a/rbac/management/group/view.py b/rbac/management/group/view.py index 5e33d113..8c872cfa 100644 --- a/rbac/management/group/view.py +++ b/rbac/management/group/view.py @@ -53,13 +53,12 @@ from rest_framework import mixins, serializers, status, viewsets from rest_framework.decorators import action from rest_framework.filters import OrderingFilter +from rest_framework.request import Request from rest_framework.response import Response from api.models import Tenant, User from .insufficient_privileges import InsufficientPrivilegesError from .service_account_not_found_error import ServiceAccountNotFoundError -from ..authorization.token_validator import ITSSOTokenValidator, InvalidTokenError, MissingAuthorizationError -from ..authorization.token_validator import UnableMeetPrerequisitesError from ..principal.unexpected_status_code_from_it import UnexpectedStatusCodeFromITError USERNAMES_KEY = "usernames" @@ -397,7 +396,6 @@ def add_service_accounts( self, user: User, group: Group, - bearer_token: str, service_accounts: Iterable[dict], account_name: str = "", org_id: str = "", @@ -407,7 +405,7 @@ def add_service_accounts( # want to skip calling IT it_service = ITService() if not settings.IT_BYPASS_IT_CALLS: - it_service_accounts = it_service.request_service_accounts(bearer_token=bearer_token) + it_service_accounts = it_service.request_service_accounts(bearer_token=user.bearer_token) # Organize them by their client ID. it_service_accounts_by_client_ids: dict[str, dict] = {} @@ -502,7 +500,7 @@ def remove_principals(self, group, principals, account=None, org_id=None): return group @action(detail=True, methods=["get", "post", "delete"]) - def principals(self, request, uuid=None): + def principals(self, request: Request, uuid=None): """Get, add or remove principals from a group.""" """ @api {get} /api/v1/groups/:uuid/principals/ Get principals for a group @@ -612,56 +610,11 @@ def principals(self, request, uuid=None): # Process the service accounts and add them to the group. if len(service_accounts) > 0: - try: - # Attempt validating the JWT token. - token_validator = ITSSOTokenValidator() - bearer_token = token_validator.validate_token(request=request) - except MissingAuthorizationError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "The authorization header is required for fetching service accounts.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except InvalidTokenError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "Invalid token provided.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except UnableMeetPrerequisitesError: - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={ - "errors": [ - { - "detail": "Unable to validate token.", - "source": "groups", - "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), - } - ] - }, - ) - try: resp = self.add_service_accounts( user=request.user, group=group, service_accounts=service_accounts, - bearer_token=bearer_token, account_name=account, org_id=org_id, ) @@ -727,61 +680,19 @@ def principals(self, request, uuid=None): # Make sure we return early for service accounts. if principalType == "service-account": - try: - # Attempt validating the JWT token. - token_validator = ITSSOTokenValidator() - bearer_token = token_validator.validate_token(request=request) - except MissingAuthorizationError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "The authorization header is required for fetching service accounts.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except InvalidTokenError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "Invalid token provided.", - "source": "groups", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except UnableMeetPrerequisitesError: - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={ - "errors": [ - { - "detail": "Unable to validate token.", - "source": "groups", - "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), - } - ] - }, - ) - # Get the service account's description and name filters, and the principal's username filter too. # Finally, get the limit and offset parameters. options[SERVICE_ACCOUNT_DESCRIPTION_KEY] = request.query_params.get(SERVICE_ACCOUNT_DESCRIPTION_KEY) options[SERVICE_ACCOUNT_NAME_KEY] = request.query_params.get(SERVICE_ACCOUNT_NAME_KEY) + + # Get the "principal username" parameter. options[PRINCIPAL_USERNAME_KEY] = request.query_params.get(PRINCIPAL_USERNAME_KEY) # Fetch the group's service accounts. it_service = ITService() try: service_accounts = it_service.get_service_accounts_group( - group=group, bearer_token=bearer_token, options=options + group=group, user=request.user, options=options ) except (requests.exceptions.ConnectionError, UnexpectedStatusCodeFromITError): return Response( diff --git a/rbac/management/principal/it_service.py b/rbac/management/principal/it_service.py index 7ec79b01..aae4e0b8 100644 --- a/rbac/management/principal/it_service.py +++ b/rbac/management/principal/it_service.py @@ -155,12 +155,12 @@ def request_service_accounts(self, bearer_token: str) -> list[dict]: return service_accounts - def get_service_accounts(self, user: User, bearer_token: str, options: dict = {}) -> Tuple[list[dict], int]: + def get_service_accounts(self, user: User, options: dict = {}) -> Tuple[list[dict], int]: """Request and returns the service accounts for the given tenant.""" # We might want to bypass calls to the IT service on ephemeral or test environments. it_service_accounts: list[dict] = [] if not settings.IT_BYPASS_IT_CALLS: - it_service_accounts = self.request_service_accounts(bearer_token=bearer_token) + it_service_accounts = self.request_service_accounts(bearer_token=user.bearer_token) # Get the service accounts from the database. The weird filter is to fetch the service accounts depending on # the account number or the organization ID the user gave. @@ -238,12 +238,12 @@ def get_service_accounts(self, user: User, bearer_token: str, options: dict = {} return service_accounts, count - def get_service_accounts_group(self, group: Group, bearer_token: str, options: dict = {}) -> list[dict]: + def get_service_accounts_group(self, group: Group, user: User, options: dict = {}) -> list[dict]: """Get the service accounts for the given group.""" # We might want to bypass calls to the IT service on ephemeral or test environments. it_service_accounts: list[dict] = [] if not settings.IT_BYPASS_IT_CALLS: - it_service_accounts = self.request_service_accounts(bearer_token=bearer_token) + it_service_accounts = self.request_service_accounts(bearer_token=user.bearer_token) # Fetch the service accounts from the group. group_service_account_principals = group.principals.filter(type=TYPE_SERVICE_ACCOUNT) diff --git a/rbac/management/principal/view.py b/rbac/management/principal/view.py index f8d58e64..fc57ca15 100644 --- a/rbac/management/principal/view.py +++ b/rbac/management/principal/view.py @@ -26,8 +26,6 @@ from .it_service import ITService from .proxy import PrincipalProxy from .unexpected_status_code_from_it import UnexpectedStatusCodeFromITError -from ..authorization.token_validator import ITSSOTokenValidator, InvalidTokenError, MissingAuthorizationError -from ..authorization.token_validator import UnableMeetPrerequisitesError from ..permissions.principal_access import PrincipalAccessPermission USERNAMES_KEY = "usernames" @@ -133,50 +131,6 @@ def get(self, request): # Get either service accounts or user principals, depending on what the user specified. if principal_type == "service-account": - try: - # Attempt validating the JWT token. - token_validator = ITSSOTokenValidator() - bearer_token = token_validator.validate_token(request=request) - except MissingAuthorizationError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "The authorization header is required for fetching service accounts.", - "source": "principals", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except InvalidTokenError: - return Response( - status=status.HTTP_401_UNAUTHORIZED, - data={ - "errors": [ - { - "detail": "Invalid token provided.", - "source": "principals", - "status": str(status.HTTP_401_UNAUTHORIZED), - } - ] - }, - ) - except UnableMeetPrerequisitesError: - return Response( - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - data={ - "errors": [ - { - "detail": "Unable to validate token.", - "source": "principals", - "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), - } - ] - }, - ) - options["email"] = query_params.get(EMAIL_KEY) options["match_criteria"] = validate_and_get_key( query_params, MATCH_CRITERIA_KEY, VALID_MATCH_VALUE, required=False @@ -189,9 +143,7 @@ def get(self, request): # Fetch the service accounts from IT. try: it_service = ITService() - service_accounts, sa_count = it_service.get_service_accounts( - user=user, bearer_token=bearer_token, options=options - ) + service_accounts, sa_count = it_service.get_service_accounts(user=user, options=options) except (requests.exceptions.ConnectionError, UnexpectedStatusCodeFromITError): return Response( status=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/rbac/rbac/middleware.py b/rbac/rbac/middleware.py index 2fd53960..9f2dcbf6 100644 --- a/rbac/rbac/middleware.py +++ b/rbac/rbac/middleware.py @@ -27,6 +27,8 @@ from django.http import Http404, HttpResponse, QueryDict from django.urls import resolve from django.utils.deprecation import MiddlewareMixin +from management.authorization.token_validator import ITSSOTokenValidator, InvalidTokenError, MissingAuthorizationError +from management.authorization.token_validator import UnableMeetPrerequisitesError from management.cache import TenantCache from management.models import Principal from management.utils import APPLICATION_KEY, access_for_principal, validate_psk @@ -232,6 +234,59 @@ def process_request(self, request): # pylint: disable=R1710 user.user_id = None user.system = False + # The requests made by service accounts are expected to come with an Authorization header which + # contains a Bearer token. Therefore, we will attempt to extract it and validate it, and also store it + # in case we need to use it to contact IT with it. + token_validator = ITSSOTokenValidator() + try: + user.bearer_token = token_validator.validate_token(request=request) + except InvalidTokenError: + return HttpResponse( + content=json.dumps( + { + "errors": [ + { + "detail": "Invalid token provided.", + "status": str(status.HTTP_401_UNAUTHORIZED), + } + ] + } + ), + content_type="application/json", + status=status.HTTP_401_UNAUTHORIZED, + ) + except MissingAuthorizationError: + return HttpResponse( + content=json.dumps( + { + "errors": [ + { + "detail": "A Bearer token in an authorization header is required when" + " contacting RBAC with a service account.", + "status": str(status.HTTP_401_UNAUTHORIZED), + } + ] + } + ), + content_type="application/json", + status=status.HTTP_401_UNAUTHORIZED, + ) + except UnableMeetPrerequisitesError: + return HttpResponse( + content=json.dumps( + { + "errors": [ + { + "detail": "Unable to validate the provided token.", + "status": str(status.HTTP_500_INTERNAL_SERVER_ERROR), + } + ] + } + ), + content_type="application/json", + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + # If we did not get the user information or service account information from the "x-rh-identity" header, # then the request is directly unauthorized. if not user_info and not service_account: diff --git a/tests/identity_request.py b/tests/identity_request.py index cf185269..dc106be1 100644 --- a/tests/identity_request.py +++ b/tests/identity_request.py @@ -15,6 +15,8 @@ # along with this program. If not, see . # """Test Case extension to collect common test data.""" +import uuid + from base64 import b64encode from json import dumps as json_dumps from unittest.mock import Mock @@ -69,21 +71,29 @@ def _create_user_data(cls): user_data = {"username": cls.fake.user_name(), "email": cls.fake.email()} return user_data + def _create_service_account_data(cls) -> dict[str, str]: + """Create service account data""" + client_id = str(uuid.uuid4()) + return {"client_id": client_id, "username": f"service-account-{client_id}"} + @classmethod def _create_request_context( cls, - customer_data, - user_data, - is_org_admin=True, - is_internal=False, - cross_account=False, + customer_data: dict[str, str], + user_data: dict[str, str], + is_org_admin: bool = True, + is_internal: bool = False, + cross_account: bool = False, + service_account_data: dict[str, str] = None, ): """Create the request context for a user.""" customer = customer_data account = customer.get("account_id") org_id = customer.get("org_id", None) - identity = cls._build_identity(user_data, account, org_id, is_org_admin, is_internal) + identity = cls._build_identity( + user_data, account, org_id, is_org_admin, is_internal, service_account_data=service_account_data + ) if cross_account: identity["identity"]["internal"] = {"cross_access": True} json_identity = json_dumps(identity) @@ -95,7 +105,15 @@ def _create_request_context( return request_context @classmethod - def _build_identity(cls, user_data, account, org_id, is_org_admin, is_internal): + def _build_identity( + cls, + user_data: dict[str, str], + account: str, + org_id: str, + is_org_admin: bool, + is_internal: bool, + service_account_data: dict[str, str] = None, + ): identity = {"identity": {"account_number": account, "org_id": org_id}} if user_data is not None: identity["identity"]["user"] = { @@ -105,11 +123,20 @@ def _build_identity(cls, user_data, account, org_id, is_org_admin, is_internal): "user_id": "1111111", } + if service_account_data: + identity["identity"]["service_account"] = { + "client_id": service_account_data.get("client_id"), + "username": service_account_data.get("username"), + } + if is_internal: identity["identity"]["type"] = "Associate" identity["identity"]["associate"] = identity.get("identity").get("user") identity["identity"]["user"]["is_internal"] = True else: - identity["identity"]["type"] = "User" + if user_data: + identity["identity"]["type"] = "User" + else: + identity["identity"]["type"] = "ServiceAccount" return identity diff --git a/tests/management/principal/test_view.py b/tests/management/principal/test_view.py index e83acd02..79cb066c 100644 --- a/tests/management/principal/test_view.py +++ b/tests/management/principal/test_view.py @@ -827,19 +827,6 @@ def test_read_principal_users(self, mock_request): cross_account_principal.delete() - def test_fetch_service_accounts(self): - """Test fetching service accounts while not providing a token - - Test that when the user request the service accounts without an authorization token, an unauthorized response - is returned - """ - - url = f'{reverse("principals")}?type=service-account' - client = APIClient() - response: Response = client.get(url, **self.headers) - - self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) - @override_settings(IT_BYPASS_TOKEN_VALIDATION=True) @patch("management.principal.it_service.ITService.request_service_accounts") def test_read_principal_service_account_list_success(self, mock_request): diff --git a/tests/rbac/test_middleware.py b/tests/rbac/test_middleware.py index 2290acf4..0c33e6bf 100644 --- a/tests/rbac/test_middleware.py +++ b/tests/rbac/test_middleware.py @@ -16,7 +16,9 @@ # """Test the project middleware.""" import collections +import json import os +from unittest import mock from unittest.mock import Mock from django.conf import settings from django.http import QueryDict @@ -31,6 +33,7 @@ from api.serializers import create_tenant_name from tests.identity_request import IdentityRequest from rbac.middleware import HttpResponseUnauthorizedRequest, IdentityHeaderMiddleware +from management.authorization.token_validator import UnableMeetPrerequisitesError from management.models import Access, Group, Permission, Principal, Policy, ResourceDefinition, Role @@ -342,6 +345,201 @@ def test_should_load_user_permissions_regular_user_access_missing_query_params(s self.assertEqual(middleware.should_load_user_permissions(request, user), False) + def test_service_account_with_no_authorization_header(self): + """Test that a 401 is returned when a service account with no authorization header contacts RBAC.""" + # Prepare the mocked service account data. + service_account_data = self._create_service_account_data() + customer = self._create_customer_data() + + # Prepare the mocked request. + request_context = self._create_request_context( + customer_data=customer, user_data=None, service_account_data=service_account_data + ) + mock_request = request_context["request"] + mock_request.path = "/api/v1/access/" + middleware = IdentityHeaderMiddleware(get_response=IdentityHeaderMiddleware.process_request) + + # Set the authorization header to an empty one to test that a missing authorization exception is raised. + mock_request.headers = {"Authorization": None} + + # Call the middleware under test. + result = middleware.process_request(mock_request) + + # Assert that the content type header has the correct value. + self.assertEqual( + "application/json", result.headers.get("Content-Type"), "the content type header has the incorrect value" + ) + + # Assert that the status code is the expected one. + self.assertEqual( + status.HTTP_401_UNAUTHORIZED, + result.status_code, + "unexpected status code received when the authorization header is missing", + ) + + # Assert that the contents of the body are the expected ones. + content = json.loads(result.content.decode("utf-8")) + errors = content["errors"] + if not errors: + self.fail('expected an "errors" array in the received response\'s body, but it was not found') + + if len(errors) != 1: + self.fail(f'expected a single error in the "errors" array, {len(errors)} errors received') + + error = errors[0] + error_detail = error.get("detail") + + if not error_detail: + self.fail("the error detail is missing from the error object") + + self.assertEqual( + error_detail, + "A Bearer token in an authorization header is required when contacting RBAC with a service account.", + "unexpected error detail received in the response", + ) + + error_status = error.get("status") + + if not error_status: + self.fail("the error object is missing the status code") + + self.assertEqual( + str(status.HTTP_401_UNAUTHORIZED), + error_status, + "unexpected status code received in the body of the response", + ) + + @mock.patch("management.authorization.token_validator.ITSSOTokenValidator._get_json_web_keyset") + def test_service_account_unable_validate_token(self, _get_json_web_keyset: Mock): + """Test 500 response for a service account request when unable to meet prerequisites to validate the token.""" + # Prepare the mocked service account data. + service_account_data = self._create_service_account_data() + customer = self._create_customer_data() + + # Prepare the mocked request. + request_context = self._create_request_context( + customer_data=customer, user_data=None, service_account_data=service_account_data + ) + mock_request = request_context["request"] + mock_request.path = "/api/v1/access/" + middleware = IdentityHeaderMiddleware(get_response=IdentityHeaderMiddleware.process_request) + + # Set a non-empty authorization header. + mock_request.headers = {"Authorization": "invalid-bearer-token"} + # Pretend that we are not able to contact IT in order to fetch the required key set to validate the token. + _get_json_web_keyset.side_effect = UnableMeetPrerequisitesError() + + # Call the middleware under test. + result = middleware.process_request(mock_request) + + # Assert that the content type header has the correct value. + self.assertEqual( + "application/json", result.headers.get("Content-Type"), "the content type header has the incorrect value" + ) + + # Assert that the status code is the expected one. + self.assertEqual( + status.HTTP_500_INTERNAL_SERVER_ERROR, + result.status_code, + "unexpected status code received when we are unable to meet the prerequisites to validate the token", + ) + + # Assert that the contents of the body are the expected ones. + content = json.loads(result.content.decode("utf-8")) + errors = content["errors"] + if not errors: + self.fail('expected an "errors" array in the received response\'s body, but it was not found') + + if len(errors) != 1: + self.fail(f'expected a single error in the "errors" array, {len(errors)} errors received') + + error = errors[0] + error_detail = error.get("detail") + + if not error_detail: + self.fail("the error detail is missing from the error object") + + self.assertEqual( + error_detail, "Unable to validate the provided token.", "unexpected error detail received in the response" + ) + + error_status = error.get("status") + + if not error_status: + self.fail("the error object is missing the status code") + + self.assertEqual( + str(status.HTTP_500_INTERNAL_SERVER_ERROR), + error_status, + "unexpected status code received in the body of the response", + ) + + @mock.patch("management.authorization.token_validator.ITSSOTokenValidator._get_json_web_keyset") + @mock.patch("management.authorization.token_validator.jwt.decode") + def test_service_account_with_invalid_bearer_token(self, decode: Mock, _get_json_web_keyset: Mock): + """Test 401 response for a service account request with an invalid bearer token.""" + # Prepare the mocked service account data. + service_account_data = self._create_service_account_data() + customer = self._create_customer_data() + + # Prepare the mocked request. + request_context = self._create_request_context( + customer_data=customer, user_data=None, service_account_data=service_account_data + ) + mock_request = request_context["request"] + mock_request.path = "/api/v1/access/" + middleware = IdentityHeaderMiddleware(get_response=IdentityHeaderMiddleware.process_request) + + # Set a non-empty authorization header. + mock_request.headers = {"Authorization": "invalid-bearer-token"} + # The token validator should avoid calling IT for the test. + _get_json_web_keyset.return_value = "invalid-return-value" + # Make the "decode" function raise any exception to test that it gets properly handled. + decode.side_effect = ValueError("invalid value") + + # Call the middleware under test. + result = middleware.process_request(mock_request) + + # Assert that the content type header has the correct value. + self.assertEqual( + "application/json", result.headers.get("Content-Type"), "the content type header has the incorrect value" + ) + + # Assert that the status code is the expected one. + self.assertEqual( + status.HTTP_401_UNAUTHORIZED, + result.status_code, + "unexpected status code received when the bearer token is invalid", + ) + + # Assert that the contents of the body are the expected ones. + content = json.loads(result.content.decode("utf-8")) + errors = content["errors"] + if not errors: + self.fail('expected an "errors" array in the received response\'s body, but it was not found') + + if len(errors) != 1: + self.fail(f'expected a single error in the "errors" array, {len(errors)} errors received') + + error = errors[0] + error_detail = error.get("detail") + + if not error_detail: + self.fail("the error detail is missing from the error object") + + self.assertEqual(error_detail, "Invalid token provided.", "unexpected error detail received in the response") + + error_status = error.get("status") + + if not error_status: + self.fail("the error object is missing the status code") + + self.assertEqual( + str(status.HTTP_401_UNAUTHORIZED), + error_status, + "unexpected status code received in the body of the response", + ) + class ServiceToService(IdentityRequest): """Tests requests without an identity header."""