diff --git a/rbac/management/group/view.py b/rbac/management/group/view.py index a1581ae8..180451a1 100644 --- a/rbac/management/group/view.py +++ b/rbac/management/group/view.py @@ -24,6 +24,7 @@ from django.conf import settings from django.db import connection from django.db import transaction +from django.db.models import Q from django.db.models.aggregates import Count from django.utils.translation import gettext as _ from django_filters import rest_framework as filters @@ -160,7 +161,8 @@ class GroupViewSet( """ queryset = Group.objects.annotate( - principalCount=Count("principals", distinct=True), policyCount=Count("policies", distinct=True) + principalCount=Count("principals", filter=Q(principals__type="user"), distinct=True), + policyCount=Count("policies", distinct=True), ) permission_classes = (GroupAccessPermission,) lookup_field = "uuid" diff --git a/rbac/management/querysets.py b/rbac/management/querysets.py index 1c82e19b..dbd868f4 100644 --- a/rbac/management/querysets.py +++ b/rbac/management/querysets.py @@ -16,7 +16,7 @@ # """Queryset helpers for management module.""" from django.conf import settings -from django.db.models import QuerySet +from django.db.models import Q, QuerySet from django.db.models.aggregates import Count from django.urls import reverse from django.utils.translation import gettext as _ @@ -60,7 +60,8 @@ def get_annotated_groups(): """Return an annotated set of groups for the tenant.""" return Group.objects.annotate( - principalCount=Count("principals", distinct=True), policyCount=Count("policies", distinct=True) + principalCount=Count("principals", filter=Q(principals__type="user"), distinct=True), + policyCount=Count("policies", distinct=True), ) diff --git a/tests/management/group/test_view.py b/tests/management/group/test_view.py index 27b03843..cb1f2dc1 100644 --- a/tests/management/group/test_view.py +++ b/tests/management/group/test_view.py @@ -29,6 +29,7 @@ from api.models import Tenant, User from management.cache import TenantCache +from management.group.serializer import GroupInputSerializer from management.models import Group, Principal, Policy, Role, ExtRoleRelation, ExtTenant from tests.core.test_kafka import copy_call_args from tests.identity_request import IdentityRequest @@ -348,6 +349,98 @@ def test_read_group_list_success(self): self.assertIsNotNone(group.get("name")) self.assertEqual(group.get("name"), self.group.name) + # check that all fields from GroupInputSerializer are present + for key in GroupInputSerializer().fields.keys(): + self.assertIn(key, group.keys()) + + @override_settings(IT_BYPASS_TOKEN_VALIDATION=True) + @patch( + "management.principal.proxy.PrincipalProxy.request_filtered_principals", + return_value={ + "status_code": 200, + "data": [ + { + "org_id": "100001", + "is_org_admin": False, + "is_internal": False, + "id": 52567473, + "username": "user_based_principal", + "account_number": "1111111", + "is_active": True, + } + ], + }, + ) + @patch( + "management.principal.it_service.ITService.request_service_accounts", + return_value=[ + { + "clientID": "b7a82f30-bcef-013c-2452-6aa2427b506c", + "name": f"service_account_name", + "description": f"Service Account description", + "owner": "jsmith", + "username": "service_account-b7a82f30-bcef-013c-2452-6aa2427b506c", + "time_created": 1706784741, + "type": "service-account", + } + ], + ) + def test_read_group_list_principalCount(self, mock_request, sa_mock_request): + """Test that correct number is returned for principalCount.""" + # Create a test data - group with 1 user based and 1 service account principal + group_name = "TestGroup" + group = Group(name=group_name, tenant=self.tenant) + group.save() + + user_based_principal = Principal(username="user_based_principal", tenant=self.test_tenant) + user_based_principal.save() + + sa_uuid = "b7a82f30-bcef-013c-2452-6aa2427b506c" + sa_based_principal = Principal( + username="service_account-" + sa_uuid, + tenant=self.tenant, + type="service-account", + service_account_id=sa_uuid, + ) + sa_based_principal.save() + + group.principals.add(user_based_principal, sa_based_principal) + self.group.save() + + # Test that /groups/{uuid}/principals/ returns correct count of user based principals + url = f"{reverse('group-principals', kwargs={'uuid': group.uuid})}" + client = APIClient() + response = client.get(url, **self.headers) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(int(response.data.get("meta").get("count")), 1) + self.assertEqual(len(response.data.get("data")), 1) + principal_out = response.data.get("data")[0] + self.assertEqual(principal_out["username"], user_based_principal.username) + + # Test that /groups/{uuid}/principals/?principal_type=service-account returns + # correct count of service account based principals + url = f"{reverse('group-principals', kwargs={'uuid': group.uuid})}?principal_type=service-account" + client = APIClient() + response = client.get(url, **self.headers) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(int(response.data.get("meta").get("count")), 1) + self.assertEqual(len(response.data.get("data")), 1) + sa_out = response.data.get("data")[0] + self.assertEqual(sa_out["username"], sa_based_principal.username) + + # Test that /groups/?name= returns 1 group with principalCount for only user based principals + url = f"{reverse('group-list')}?name={group_name}" + client = APIClient() + response = client.get(url, **self.headers) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data.get("data")), 1) + + group = response.data.get("data")[0] + self.assertEqual(group["principalCount"], 1) + def test_get_group_by_partial_name_by_default(self): """Test that getting groups by name returns partial match by default.""" url = reverse("group-list")