From c14ec9551f5551f557fb8ed99d39555884b537b8 Mon Sep 17 00:00:00 2001 From: Mikel Alejo Barcina Ribera Date: Thu, 29 Feb 2024 18:03:54 +0100 Subject: [PATCH] fix: service accounts not being validated When validating a single service account, we were sending requests to IT filtered by that single service account to avoid fetching the entire collection. However, due to a bug, the collections were always being returned empty. Until the bug is solved, we need to fetch the entire collections instead. RHCLOUD-31265 --- rbac/management/principal/it_service.py | 24 ++++----- tests/management/principal/test_it_service.py | 54 +++++++++++++++++++ 2 files changed, 63 insertions(+), 15 deletions(-) diff --git a/rbac/management/principal/it_service.py b/rbac/management/principal/it_service.py index 0906b180..6d9351f6 100644 --- a/rbac/management/principal/it_service.py +++ b/rbac/management/principal/it_service.py @@ -202,22 +202,16 @@ def _is_service_account_valid(self, user: User, client_id: str) -> bool: if settings.IT_BYPASS_IT_CALLS: return True else: - service_accounts: list[dict] = self.request_service_accounts( - bearer_token=user.bearer_token, - client_ids=[client_id], - ) + # In theory, we should be able to pass the client ID to the function below to just get the specified + # service account and check if it is present or not. However, due to a bug, we need to fetch the whole + # collection for now. More details in https://issues.redhat.com/browse/RHCLOUD-31265 . + service_accounts: list[dict] = self.request_service_accounts(bearer_token=user.bearer_token) - if len(service_accounts) == 0: - return False - elif len(service_accounts) == 1: - sa = service_accounts[0] - return client_id == sa.get("clientID") - else: - LOGGER.error( - f'unexpected number of service accounts received from IT. Wanted one with client ID "{client_id}",' - f" got {len(service_accounts)}: {service_accounts}" - ) - return False + for sa in service_accounts: + if client_id == sa.get("clientID"): + return True + + return False def get_service_accounts(self, user: User, options: dict = {}) -> Tuple[list[dict], int]: """Request and returns the service accounts for the given tenant.""" diff --git a/tests/management/principal/test_it_service.py b/tests/management/principal/test_it_service.py index e80e92e8..27ffb5f5 100644 --- a/tests/management/principal/test_it_service.py +++ b/tests/management/principal/test_it_service.py @@ -87,6 +87,60 @@ def test_is_service_account_valid_bypass_it_calls(self, _): finally: settings.IT_BYPASS_IT_CALLS = original_bypass_it_calls_value + @mock.patch("management.principal.it_service.ITService.request_service_accounts") + def test_is_service_account_valid(self, request_service_accounts: mock.Mock): + """Tests that the service account is considered valid when there is a match between the response from IT and the requested service account""" + user = User() + user.bearer_token = "mocked-bt" + + expected_client_id = str(uuid.uuid4()) + request_service_accounts.return_value = [{"clientID": expected_client_id}] + + self.assertEqual( + True, + self.it_service._is_service_account_valid(user=user, client_id=expected_client_id), + "when IT responds with a single service account and it matches, the function under test should return 'True'", + ) + + request_service_accounts.return_value = [ + {"clientID": str(uuid.uuid4())}, + {"clientID": str(uuid.uuid4())}, + {"clientID": expected_client_id}, + ] + + self.assertEqual( + True, + self.it_service._is_service_account_valid(user=user, client_id=expected_client_id), + "when IT responds with multiple service accounts and one of them matches, the function under test should return 'True'", + ) + + @mock.patch("management.principal.it_service.ITService.request_service_accounts") + def test_is_service_account_invalid(self, request_service_accounts: mock.Mock): + """Tests that the service account is considered invalid when there isn't a match between the response from IT and the requested service account""" + user = User() + user.bearer_token = "mocked-bt" + + expected_client_id = str(uuid.uuid4()) + request_service_accounts.return_value = [] + + self.assertEqual( + False, + self.it_service._is_service_account_valid(user=user, client_id=expected_client_id), + "when IT responds with a single service account and it does not match, the function under test should return 'False'", + ) + + request_service_accounts.return_value = [ + {"clientID": str(uuid.uuid4())}, + {"clientID": str(uuid.uuid4())}, + {"clientID": str(uuid.uuid4())}, + ] + + self.assertEqual( + False, + self.it_service._is_service_account_valid(user=user, client_id=expected_client_id), + "when IT responds with multiple service accounts and none of them match, the function under test should return 'False'", + ) + @mock.patch("management.principal.it_service.ITService.request_service_accounts") def test_is_service_account_valid_zero_results_from_it(self, request_service_accounts: mock.Mock): """Test that the function under test treats an empty result from IT as an invalid service account."""