diff --git a/rbac/management/querysets.py b/rbac/management/querysets.py index 538640b1..d5a67eca 100644 --- a/rbac/management/querysets.py +++ b/rbac/management/querysets.py @@ -36,6 +36,7 @@ policies_for_principal, queryset_by_id, roles_for_principal, + validate_and_get_key, ) from rest_framework import permissions, serializers from rest_framework.request import Request @@ -97,7 +98,7 @@ def _gather_group_querysets(request, args, kwargs): """Decide which groups to provide for request.""" username = request.query_params.get("username") - scope = request.query_params.get(SCOPE_KEY, ORG_ID_SCOPE) + scope = validate_and_get_key(request.query_params, SCOPE_KEY, VALID_SCOPES, ORG_ID_SCOPE) if scope != ORG_ID_SCOPE and not username: return get_object_principal_queryset(request, scope, Group) @@ -149,7 +150,7 @@ def annotate_roles_with_counts(queryset): def get_role_queryset(request) -> QuerySet: """Obtain the queryset for roles.""" - scope = request.query_params.get(SCOPE_KEY, ORG_ID_SCOPE) + scope = validate_and_get_key(request.query_params, SCOPE_KEY, VALID_SCOPES, ORG_ID_SCOPE) public_tenant = Tenant.objects.get(tenant_name="public") base_query = annotate_roles_with_counts(Role.objects.prefetch_related("access")).filter( tenant__in=[request.tenant, public_tenant] @@ -213,7 +214,7 @@ def get_role_queryset(request) -> QuerySet: def get_policy_queryset(request): """Obtain the queryset for policies.""" - scope = request.query_params.get(SCOPE_KEY, ORG_ID_SCOPE) + scope = validate_and_get_key(request.query_params, SCOPE_KEY, VALID_SCOPES, ORG_ID_SCOPE) if scope != ORG_ID_SCOPE: return get_object_principal_queryset(request, scope, Policy) diff --git a/rbac/management/role/serializer.py b/rbac/management/role/serializer.py index 7ba142f4..3caff09b 100644 --- a/rbac/management/role/serializer.py +++ b/rbac/management/role/serializer.py @@ -20,12 +20,12 @@ from management.group.model import Group from management.notifications.notification_handlers import role_obj_change_notification_handler from management.serializer_override_mixin import SerializerCreateOverrideMixin -from management.utils import filter_queryset_by_tenant, get_principal +from management.utils import filter_queryset_by_tenant, get_principal, validate_and_get_key from rest_framework import serializers from api.models import Tenant from .model import Access, Permission, ResourceDefinition, Role -from ..querysets import PRINCIPAL_SCOPE +from ..querysets import ORG_ID_SCOPE, PRINCIPAL_SCOPE, SCOPE_KEY, VALID_SCOPES ALLOWED_OPERATIONS = ["in", "equal"] FILTER_FIELDS = {"key", "value", "operation"} @@ -329,7 +329,7 @@ def obtain_applications(obj): def obtain_groups_in(obj, request): """Shared function to get the groups the roles is in.""" - scope_param = request.query_params.get("scope") + scope_param = validate_and_get_key(request.query_params, SCOPE_KEY, VALID_SCOPES, ORG_ID_SCOPE) username_param = request.query_params.get("username") policy_ids = list(obj.policies.values_list("id", flat=True)) diff --git a/tests/management/group/test_view.py b/tests/management/group/test_view.py index 4e849ed0..36351ec1 100644 --- a/tests/management/group/test_view.py +++ b/tests/management/group/test_view.py @@ -2720,7 +2720,7 @@ def setUp(self): "permission." ) self.invalid_value_for_scope_query_param = ( - "scope query parameter value foo is invalid. [org_id, principal] are valid inputs." + "scope query parameter value 'foo' is invalid. ['org_id', 'principal'] are valid inputs." ) self.user_access_admin_role_err_message = ( "Non org admin users are not allowed to add RBAC role with higher than 'read' permission into groups." @@ -4771,7 +4771,7 @@ def test_read_group_with_username_and_scope_params(self, mock): # Adding the 'scope' param doesn't affect the response because the 'scope' param is ignored # when query contains the 'username' param - for scope in ("org_id", "principal", "foo"): + for scope in ("org_id", "principal"): url_with_scope = url + f"&scope={scope}" response = client.get(url_with_scope, format="json", **self.headers_user_based_principal) self.assertEqual(response.status_code, status.HTTP_200_OK)