Skip to content

Commit

Permalink
refactor: don't transform the incoming client IDs into UUIDs
Browse files Browse the repository at this point in the history
RHCLOUD-30299
  • Loading branch information
MikelAlejoBR committed Feb 21, 2024
1 parent d13fcf5 commit 9d42181
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 42 deletions.
6 changes: 3 additions & 3 deletions rbac/management/group/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,10 +693,10 @@ def principals(self, request: Request, uuid: Optional[UUID] = None):
# Turn the received and comma separated client IDs into a manageable set.
received_client_ids: set[str] = set(service_account_client_ids_raw.split(","))

processed_client_ids: set[UUID] = set()
# Validate that the provided strings are actually UUIDs.
for rci in received_client_ids:
try:
processed_client_ids.add(UUID(str(rci)))
UUID(rci)
except ValueError:
return Response(
status=status.HTTP_400_BAD_REQUEST,
Expand All @@ -715,7 +715,7 @@ def principals(self, request: Request, uuid: Optional[UUID] = None):
# ones are available to be added to the given group.
it_service = ITService()
result: dict = it_service.generate_service_accounts_report_in_group(
group=group, client_ids=processed_client_ids
group=group, client_ids=received_client_ids
)

# Prettify the output payload and return it.
Expand Down
9 changes: 3 additions & 6 deletions rbac/management/principal/it_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import time
import uuid
from typing import Optional, Tuple, Union
from uuid import UUID

import requests
from django.conf import settings
Expand Down Expand Up @@ -403,7 +402,7 @@ def extract_client_id_service_account_username(username: str) -> uuid.UUID:
}
)

def generate_service_accounts_report_in_group(self, group: Group, client_ids: set[UUID]) -> dict[str, bool]:
def generate_service_accounts_report_in_group(self, group: Group, client_ids: set[str]) -> dict[str, bool]:
"""Check if the given service accounts are in the specified group."""
# Fetch the service accounts from the group.
group_service_account_principals = (
Expand All @@ -414,10 +413,8 @@ def generate_service_accounts_report_in_group(self, group: Group, client_ids: se

# Mark the specified client IDs as "present or missing" from the result set.
result: dict[str, bool] = {}
for rci_uuid in client_ids:
rci = str(rci_uuid)

result[rci] = rci in group_service_account_principals
for incoming_client_id in client_ids:
result[incoming_client_id] = incoming_client_id in group_service_account_principals

return result

Expand Down
52 changes: 19 additions & 33 deletions tests/management/principal/test_it_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,10 @@ def test_generate_service_accounts_report_in_group_zero_matches(self):
group.save()

# Simulate that a few client IDs were specified in the request.
request_client_ids = set[uuid.UUID]()
request_client_ids.add(uuid.uuid4())
request_client_ids.add(uuid.uuid4())
request_client_ids.add(uuid.uuid4())
request_client_ids = set[str]()
request_client_ids.add(str(uuid.uuid4()))
request_client_ids.add(str(uuid.uuid4()))
request_client_ids.add(str(uuid.uuid4()))

# Call the function under test.
result: dict[str, bool] = self.it_service.generate_service_accounts_report_in_group(
Expand All @@ -250,17 +250,12 @@ def test_generate_service_accounts_report_in_group_zero_matches(self):
# Assert that only the specified client IDs are present in the result.
self.assertEqual(3, len(result))

# Transform the UUIDs to strings to match the generated result and be able to create assertions.
request_client_ids_str: set[str] = set()
for rci in request_client_ids:
request_client_ids_str.add(str(rci))

# Assert that all the service accounts were flagged as not present in the group.
for client_id, is_present_in_group in result.items():
# Make sure the specified client IDs are in the set.
self.assertEqual(
True,
client_id in request_client_ids_str,
client_id in request_client_ids,
"expected to find the specified client ID from the request in the returning result",
)
# Make sure they are all set to "false" since there shouldn't be any of those client IDs in the group.
Expand Down Expand Up @@ -348,14 +343,15 @@ def test_generate_service_accounts_report_in_group_mixed_results(self):
}

# Add all the UUIDs to a set to pass it to the function under test.
request_client_ids = set[uuid.UUID]()
request_client_ids.add(not_in_group)
request_client_ids.add(not_in_group_2)
request_client_ids.add(not_in_group_3)
request_client_ids = set[str]()
request_client_ids.add(str(not_in_group))
request_client_ids.add(str(not_in_group_2))
request_client_ids.add(str(not_in_group_3))

# Specify the service accounts' UUIDs here too, because the function under test should flag them as present in
# the group.
request_client_ids.add(client_uuid_1)
request_client_ids.add(client_uuid_2)
request_client_ids.add(str(client_uuid_1))
request_client_ids.add(str(client_uuid_2))

# Call the function under test.
result: dict[str, bool] = self.it_service.generate_service_accounts_report_in_group(
Expand All @@ -365,11 +361,6 @@ def test_generate_service_accounts_report_in_group_mixed_results(self):
# Assert that all the specified client IDs are present in the result.
self.assertEqual(5, len(result))

# Transform the UUIDs to strings to match the generated result and be able to create assertions.
request_client_ids_str: set[str] = set()
for rci in request_client_ids:
request_client_ids_str.add(str(rci))

# Assert that the mixed matches are identified correctly.
for client_id, is_it_present_in_group in result.items():
# If the value is "true" it should be present in the service accounts' result set from above. Else, it
Expand Down Expand Up @@ -456,12 +447,12 @@ def test_generate_service_accounts_report_in_group_full_match(self):
group.save()

# Simulate that a few client IDs were specified in the request.
request_client_ids = set[uuid.UUID]()
request_client_ids.add(client_uuid_1)
request_client_ids.add(client_uuid_2)
request_client_ids.add(client_uuid_3)
request_client_ids.add(client_uuid_4)
request_client_ids.add(client_uuid_5)
request_client_ids = set[str]()
request_client_ids.add(str(client_uuid_1))
request_client_ids.add(str(client_uuid_2))
request_client_ids.add(str(client_uuid_3))
request_client_ids.add(str(client_uuid_4))
request_client_ids.add(str(client_uuid_5))

# Call the function under test.
result: dict[str, bool] = self.it_service.generate_service_accounts_report_in_group(
Expand All @@ -471,16 +462,11 @@ def test_generate_service_accounts_report_in_group_full_match(self):
# Assert that all the specified client IDs are present in the result.
self.assertEqual(5, len(result))

# Transform the UUIDs to strings to match the generated result and be able to create assertions.
request_client_ids_str: set[str] = set()
for rci in request_client_ids:
request_client_ids_str.add(str(rci))

# Assert that all the results are flagged as being part of the group.
for client_id, is_present_in_group in result.items():
self.assertEqual(
True,
client_id in request_client_ids_str,
client_id in request_client_ids,
"expected to find the specified client ID from the request in the returning result",
)
self.assertEqual(
Expand Down

0 comments on commit 9d42181

Please sign in to comment.