diff --git a/rbac/rbac/middleware.py b/rbac/rbac/middleware.py index f2c2c2b8..2fd53960 100644 --- a/rbac/rbac/middleware.py +++ b/rbac/rbac/middleware.py @@ -22,8 +22,9 @@ from json.decoder import JSONDecodeError from django.conf import settings +from django.core.handlers.wsgi import WSGIRequest from django.db import IntegrityError -from django.http import Http404, HttpResponse +from django.http import Http404, HttpResponse, QueryDict from django.urls import resolve from django.utils.deprecation import MiddlewareMixin from management.cache import TenantCache @@ -259,7 +260,7 @@ def process_request(self, request): # pylint: disable=R1710 return HttpResponse(json.dumps(payload), content_type="application/json", status=400) if settings.AUTHENTICATE_WITH_ORG_ID: - if not user.admin and not (request.path.endswith("/access/") and request.method == "GET"): + if self.should_load_user_permissions(request, user): try: tenant = Tenant.objects.filter(org_id=user.org_id).get() except Tenant.DoesNotExist: @@ -268,7 +269,7 @@ def process_request(self, request): # pylint: disable=R1710 user.access = IdentityHeaderMiddleware._get_access_for_user(user.username, tenant) else: - if not user.admin and not (request.path.endswith("/access/") and request.method == "GET"): + if self.should_load_user_permissions(request, user): try: tenant_name = create_tenant_name(user.account) tenant = Tenant.objects.filter(tenant_name=tenant_name).get() @@ -416,6 +417,29 @@ def process_response(self, request, response): # pylint: disable=no-self-use IdentityHeaderMiddleware.log_request(request, response, is_internal) return response + def should_load_user_permissions(self, request: WSGIRequest, user: User) -> bool: + """Decide whether RBAC should load the access permissions for the user based on the given request.""" + # Organization administrators will have already all the permissions so there is no need to load permissions for + # them. + if user.admin: + return False + + # The access endpoint gets a lot of traffic, so we need to restrict for which queries we are actually going + # to load the user permissions, since it is a very heavy operation. The following Jira tickets have more + # details: + # + # - RHCLOUD-15394 + # - RHCLOUD-29631 + # + # There is one use case where we need to load the user's permissions: whenever they want to query for their + # or other users' permissions. In that case, we need to know if they're allowed to do so, and for that, we + # need to preload their permissions to check them afterward in the subsequent permission checkers. + if request.path.endswith("/access/") and request.method == "GET": + query_params: QueryDict = request.GET + return "username" in query_params and "application" in query_params + else: + return True + class DisableCSRF(MiddlewareMixin): # pylint: disable=too-few-public-methods """Middleware to disable CSRF for 3scale usecase.""" diff --git a/tests/rbac/test_middleware.py b/tests/rbac/test_middleware.py index d4c3755e..2290acf4 100644 --- a/tests/rbac/test_middleware.py +++ b/tests/rbac/test_middleware.py @@ -19,6 +19,7 @@ import os from unittest.mock import Mock from django.conf import settings +from django.http import QueryDict from django.test import TestCase from django.urls import reverse @@ -273,6 +274,74 @@ def test_tenant_process_without_org_id(self): self.assertEqual(Tenant.objects.filter(tenant_name="test_user").count(), 1) self.assertEqual(Tenant.objects.filter(tenant_name="test_user").first().org_id, None) + def test_should_load_user_permissions_org_admin(self): + """Tests that the function that determines if user permissions should be loaded returns False for org admins.""" + user = User() + user.admin = True + + middleware = IdentityHeaderMiddleware(get_response=Mock()) + self.assertEqual(middleware.should_load_user_permissions(Mock(), user), False) + + def test_should_load_user_permissions_regular_user_non_access_endpoint(self): + """Tests that the function under test returns True for regular users who have requested a path which isn't the access path""" + user = User() + user.admin = False + + request = Mock() + request.path = "/principals/" + + middleware = IdentityHeaderMiddleware(get_response=Mock()) + self.assertEqual(middleware.should_load_user_permissions(request, user), True) + + def test_should_load_user_permissions_regular_user_access_non_get_request(self): + """Tests that the function under test returns True for regular users who have requested the access path but with a different HTTP verb than GET""" + user = User() + user.admin = False + + request = Mock() + request.path = "/access/" + + middleware = IdentityHeaderMiddleware(get_response=Mock()) + + http_verbs = ["DELETE", "PATCH", "POST"] + for verb in http_verbs: + request.method = verb + self.assertEqual(middleware.should_load_user_permissions(request, user), True) + + def test_should_load_user_permissions_regular_user_access(self): + """Tests that the function under test returns True for regular users who have requested the access path with the expected query parameters""" + user = User() + user.admin = False + + request = Mock() + request.path = "/access/" + request.method = "GET" + request.GET = QueryDict("application=rbac&username=foo") + middleware = IdentityHeaderMiddleware(get_response=Mock()) + self.assertEqual(middleware.should_load_user_permissions(request, user), True) + + def test_should_load_user_permissions_regular_user_access_missing_query_params(self): + """Tests that the function under test returns False for regular users who have requested the access path without the expected query parameters""" + user = User() + user.admin = False + + request = Mock() + request.path = "/access/" + request.method = "GET" + + test_cases: list[QueryDict] = [ + QueryDict("application=rbac"), + QueryDict("username=foo"), + QueryDict("applications=rbac&username=foo"), + QueryDict("application=rbac&usernames=foo"), + ] + + middleware = IdentityHeaderMiddleware(get_response=Mock()) + for test_case in test_cases: + request.GET = test_case + + self.assertEqual(middleware.should_load_user_permissions(request, user), False) + class ServiceToService(IdentityRequest): """Tests requests without an identity header."""