From 1b2c88226f0926551983f5a714815a67ef8de43c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Belchior?= Date: Thu, 29 Feb 2024 14:55:48 -0300 Subject: [PATCH] RHCLOUD-29641: Validate Service Account usernames before requests --- rbac/management/principal/it_service.py | 34 ++++++++++++------- tests/management/principal/test_it_service.py | 17 ++++++++-- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/rbac/management/principal/it_service.py b/rbac/management/principal/it_service.py index 6d9351f6..34a6d028 100644 --- a/rbac/management/principal/it_service.py +++ b/rbac/management/principal/it_service.py @@ -35,6 +35,7 @@ LOGGER = logging.getLogger(__name__) SERVICE_ACCOUNT_CLIENT_IDS_KEY = "service_account_client_ids" TYPE_SERVICE_ACCOUNT = "service-account" +KEY_SERVICE_ACCOUNT = "service-account-" # IT path to fetch the service accounts. IT_PATH_GET_SERVICE_ACCOUNTS = "/service_accounts/v1" @@ -191,7 +192,7 @@ def is_service_account_valid_by_username(self, user: User, service_account_usern return True if self.is_username_service_account(service_account_username): - client_id = service_account_username.replace("service-account-", "") + client_id = service_account_username.replace(KEY_SERVICE_ACCOUNT, "") else: client_id = service_account_username @@ -376,25 +377,32 @@ def get_service_accounts_group(self, group: Group, user: User, options: dict = { @staticmethod def is_username_service_account(username: str) -> bool: """Check if the given username belongs to a service account.""" - return username.startswith("service-account-") + starts_with = username.startswith(KEY_SERVICE_ACCOUNT) + + # Validate the UUID for the ClientID reference + if starts_with: + try: + if username.count(KEY_SERVICE_ACCOUNT) != 1: + raise ValueError + + uuid.UUID(username.replace(KEY_SERVICE_ACCOUNT, "")) + except ValueError: + raise serializers.ValidationError({"detail": "Invalid format for a Service Account username"}) + + return starts_with @staticmethod def extract_client_id_service_account_username(username: str) -> uuid.UUID: """Extract the client ID from the service account's username.""" # If it has the "service-account" prefix, we just need to strip it and return the rest of the username, which # contains the client ID. Else, we have just received the client ID. - try: - if ITService.is_username_service_account(username=username): - return uuid.UUID(username.replace("service-account-", "")) - else: + if ITService.is_username_service_account(username=username): + return uuid.UUID(username.replace(KEY_SERVICE_ACCOUNT, "")) + else: + try: return uuid.UUID(username) - except ValueError: - raise serializers.ValidationError( - { - "detail": "unable to extract the client ID from the service account's username because the" - " provided UUID is invalid" - } - ) + except ValueError: + raise serializers.ValidationError({"detail": "Invalid ClientId for a Service Account username"}) def generate_service_accounts_report_in_group(self, group: Group, client_ids: set[str]) -> dict[str, bool]: """Check if the given service accounts are in the specified group.""" diff --git a/tests/management/principal/test_it_service.py b/tests/management/principal/test_it_service.py index 27ffb5f5..8d8a8aef 100644 --- a/tests/management/principal/test_it_service.py +++ b/tests/management/principal/test_it_service.py @@ -239,15 +239,28 @@ def test_extract_client_id_service_account_username(self) -> None: "the client ID was not correctly extracted from a full username", ) + # Call the function under test with a username without client ID (UUID). + try: + self.assertFalse(ITService.extract_client_id_service_account_username(username="abcde")) + self.fail( + "when providing an invalid UUID as the client ID to be extracted, the function under test should raise an error" + ) + except serializers.ValidationError as ve: + self.assertEqual( + "Invalid ClientId for a Service Account username", + str(ve.detail.get("detail")), + "unexpected error message when providing an invalid UUID as the client ID", + ) + # Call the function under test with an invalid username which contains a bad formed UUID. try: - ITService.extract_client_id_service_account_username(username="abcde") + ITService.extract_client_id_service_account_username(username="service-account-xxxxx") self.fail( "when providing an invalid UUID as the client ID to be extracted, the function under test should raise an error" ) except serializers.ValidationError as ve: self.assertEqual( - "unable to extract the client ID from the service account's username because the provided UUID is invalid", + "Invalid format for a Service Account username", str(ve.detail.get("detail")), "unexpected error message when providing an invalid UUID as the client ID", )