From 9d42181e960c3efabbd5c0f7c0f40c438404d638 Mon Sep 17 00:00:00 2001 From: Mikel Alejo Barcina Ribera Date: Tue, 20 Feb 2024 16:21:03 +0100 Subject: [PATCH] refactor: don't transform the incoming client IDs into UUIDs RHCLOUD-30299 --- rbac/management/group/view.py | 6 +-- rbac/management/principal/it_service.py | 9 ++-- tests/management/principal/test_it_service.py | 52 +++++++------------ 3 files changed, 25 insertions(+), 42 deletions(-) diff --git a/rbac/management/group/view.py b/rbac/management/group/view.py index 294f0eba7..529efa2c9 100644 --- a/rbac/management/group/view.py +++ b/rbac/management/group/view.py @@ -693,10 +693,10 @@ def principals(self, request: Request, uuid: Optional[UUID] = None): # Turn the received and comma separated client IDs into a manageable set. received_client_ids: set[str] = set(service_account_client_ids_raw.split(",")) - processed_client_ids: set[UUID] = set() + # Validate that the provided strings are actually UUIDs. for rci in received_client_ids: try: - processed_client_ids.add(UUID(str(rci))) + UUID(rci) except ValueError: return Response( status=status.HTTP_400_BAD_REQUEST, @@ -715,7 +715,7 @@ def principals(self, request: Request, uuid: Optional[UUID] = None): # ones are available to be added to the given group. it_service = ITService() result: dict = it_service.generate_service_accounts_report_in_group( - group=group, client_ids=processed_client_ids + group=group, client_ids=received_client_ids ) # Prettify the output payload and return it. diff --git a/rbac/management/principal/it_service.py b/rbac/management/principal/it_service.py index 8dd8c32a1..18b6bfdf8 100644 --- a/rbac/management/principal/it_service.py +++ b/rbac/management/principal/it_service.py @@ -19,7 +19,6 @@ import time import uuid from typing import Optional, Tuple, Union -from uuid import UUID import requests from django.conf import settings @@ -403,7 +402,7 @@ def extract_client_id_service_account_username(username: str) -> uuid.UUID: } ) - def generate_service_accounts_report_in_group(self, group: Group, client_ids: set[UUID]) -> dict[str, bool]: + def generate_service_accounts_report_in_group(self, group: Group, client_ids: set[str]) -> dict[str, bool]: """Check if the given service accounts are in the specified group.""" # Fetch the service accounts from the group. group_service_account_principals = ( @@ -414,10 +413,8 @@ def generate_service_accounts_report_in_group(self, group: Group, client_ids: se # Mark the specified client IDs as "present or missing" from the result set. result: dict[str, bool] = {} - for rci_uuid in client_ids: - rci = str(rci_uuid) - - result[rci] = rci in group_service_account_principals + for incoming_client_id in client_ids: + result[incoming_client_id] = incoming_client_id in group_service_account_principals return result diff --git a/tests/management/principal/test_it_service.py b/tests/management/principal/test_it_service.py index f3479d3c1..94c2889fa 100644 --- a/tests/management/principal/test_it_service.py +++ b/tests/management/principal/test_it_service.py @@ -238,10 +238,10 @@ def test_generate_service_accounts_report_in_group_zero_matches(self): group.save() # Simulate that a few client IDs were specified in the request. - request_client_ids = set[uuid.UUID]() - request_client_ids.add(uuid.uuid4()) - request_client_ids.add(uuid.uuid4()) - request_client_ids.add(uuid.uuid4()) + request_client_ids = set[str]() + request_client_ids.add(str(uuid.uuid4())) + request_client_ids.add(str(uuid.uuid4())) + request_client_ids.add(str(uuid.uuid4())) # Call the function under test. result: dict[str, bool] = self.it_service.generate_service_accounts_report_in_group( @@ -250,17 +250,12 @@ def test_generate_service_accounts_report_in_group_zero_matches(self): # Assert that only the specified client IDs are present in the result. self.assertEqual(3, len(result)) - # Transform the UUIDs to strings to match the generated result and be able to create assertions. - request_client_ids_str: set[str] = set() - for rci in request_client_ids: - request_client_ids_str.add(str(rci)) - # Assert that all the service accounts were flagged as not present in the group. for client_id, is_present_in_group in result.items(): # Make sure the specified client IDs are in the set. self.assertEqual( True, - client_id in request_client_ids_str, + client_id in request_client_ids, "expected to find the specified client ID from the request in the returning result", ) # Make sure they are all set to "false" since there shouldn't be any of those client IDs in the group. @@ -348,14 +343,15 @@ def test_generate_service_accounts_report_in_group_mixed_results(self): } # Add all the UUIDs to a set to pass it to the function under test. - request_client_ids = set[uuid.UUID]() - request_client_ids.add(not_in_group) - request_client_ids.add(not_in_group_2) - request_client_ids.add(not_in_group_3) + request_client_ids = set[str]() + request_client_ids.add(str(not_in_group)) + request_client_ids.add(str(not_in_group_2)) + request_client_ids.add(str(not_in_group_3)) + # Specify the service accounts' UUIDs here too, because the function under test should flag them as present in # the group. - request_client_ids.add(client_uuid_1) - request_client_ids.add(client_uuid_2) + request_client_ids.add(str(client_uuid_1)) + request_client_ids.add(str(client_uuid_2)) # Call the function under test. result: dict[str, bool] = self.it_service.generate_service_accounts_report_in_group( @@ -365,11 +361,6 @@ def test_generate_service_accounts_report_in_group_mixed_results(self): # Assert that all the specified client IDs are present in the result. self.assertEqual(5, len(result)) - # Transform the UUIDs to strings to match the generated result and be able to create assertions. - request_client_ids_str: set[str] = set() - for rci in request_client_ids: - request_client_ids_str.add(str(rci)) - # Assert that the mixed matches are identified correctly. for client_id, is_it_present_in_group in result.items(): # If the value is "true" it should be present in the service accounts' result set from above. Else, it @@ -456,12 +447,12 @@ def test_generate_service_accounts_report_in_group_full_match(self): group.save() # Simulate that a few client IDs were specified in the request. - request_client_ids = set[uuid.UUID]() - request_client_ids.add(client_uuid_1) - request_client_ids.add(client_uuid_2) - request_client_ids.add(client_uuid_3) - request_client_ids.add(client_uuid_4) - request_client_ids.add(client_uuid_5) + request_client_ids = set[str]() + request_client_ids.add(str(client_uuid_1)) + request_client_ids.add(str(client_uuid_2)) + request_client_ids.add(str(client_uuid_3)) + request_client_ids.add(str(client_uuid_4)) + request_client_ids.add(str(client_uuid_5)) # Call the function under test. result: dict[str, bool] = self.it_service.generate_service_accounts_report_in_group( @@ -471,16 +462,11 @@ def test_generate_service_accounts_report_in_group_full_match(self): # Assert that all the specified client IDs are present in the result. self.assertEqual(5, len(result)) - # Transform the UUIDs to strings to match the generated result and be able to create assertions. - request_client_ids_str: set[str] = set() - for rci in request_client_ids: - request_client_ids_str.add(str(rci)) - # Assert that all the results are flagged as being part of the group. for client_id, is_present_in_group in result.items(): self.assertEqual( True, - client_id in request_client_ids_str, + client_id in request_client_ids, "expected to find the specified client ID from the request in the returning result", ) self.assertEqual(