diff --git a/rbac/management/role/serializer.py b/rbac/management/role/serializer.py index cb4bf515..1cd0dd19 100644 --- a/rbac/management/role/serializer.py +++ b/rbac/management/role/serializer.py @@ -144,10 +144,8 @@ def update(self, instance, validated_data): """Update the role object in the database.""" access_list = validated_data.pop("access") tenant = self.context["request"].tenant - role_name = instance.name - update_data = validate_role_update(instance, validated_data) - instance = update_role(role_name, update_data, tenant) + instance = update_role(instance, validated_data) create_access_for_role(instance, access_list, tenant) @@ -161,6 +159,15 @@ def get_external_tenant(self, obj): """Get the external tenant name if it's from an external tenant.""" return obj.external_tenant_name() + def validate(self, data): + """Validate the input data of role.""" + if self.instance and self.instance.system: + key = "role.update" + message = "System roles may not be updated." + error = {key: [_(message)]} + raise serializers.ValidationError(error) + return super().validate(data) + class RoleMinimumSerializer(SerializerCreateOverrideMixin, serializers.ModelSerializer): """Serializer for the Role model that doesn't return access info.""" @@ -308,13 +315,18 @@ class RolePatchSerializer(RoleSerializer): def update(self, instance, validated_data): """Patch the role object.""" - tenant = self.context["request"].tenant - role_name = instance.name - update_data = validate_role_update(instance, validated_data) - - instance = update_role(role_name, update_data, tenant, clear_access=False) + instance = update_role(instance, validated_data, clear_access=False) return instance + def validate(self, data): + """Validate the input data of patching role.""" + if self.instance.system: + key = "role.update" + message = "System roles may not be updated." + error = {key: [_(message)]} + raise serializers.ValidationError(error) + return super().validate(data) + class BindingMappingSerializer(serializers.ModelSerializer): """Serializer for the binding mapping.""" @@ -382,43 +394,19 @@ def create_access_for_role(role, access_list, tenant): ResourceDefinition.objects.create(**resource_def_item, access=access_obj, tenant=tenant) -def validate_role_update(instance, validated_data): - """Validate if role could be updated.""" - if instance.system: - key = "role.update" - message = "System roles may not be updated." - error = {key: [_(message)]} - raise serializers.ValidationError(error) - updated_name = validated_data.get("name", instance.name) - updated_display_name = validated_data.get("display_name", instance.display_name) - updated_description = validated_data.get("description", instance.description) - - return { - "updated_name": updated_name, - "updated_display_name": updated_display_name, - "updated_description": updated_description, - } - - -def update_role(role_name, update_data, tenant, clear_access=True): +def update_role(instance, validated_data, clear_access=True): """Update role attribute.""" - role = Role.objects.get(name=role_name, tenant=tenant) - update_fields = [] - if "updated_name" in update_data: - role.name = update_data["updated_name"] - update_fields.append("name") - if "updated_display_name" in update_data: - role.display_name = update_data["updated_display_name"] - update_fields.append("display_name") - if "updated_description" in update_data: - role.description = update_data["updated_description"] - update_fields.append("description") + for field_name in ["name", "display_name", "description"]: + if field_name not in validated_data: + continue + setattr(instance, field_name, validated_data[field_name]) + update_fields.append(field_name) - role.save(update_fields=update_fields) + instance.save(update_fields=update_fields) if clear_access: - role.access.all().delete() + instance.access.all().delete() - return role + return instance diff --git a/tests/management/role/test_view.py b/tests/management/role/test_view.py index c0c10c1e..7f4477c9 100644 --- a/tests/management/role/test_view.py +++ b/tests/management/role/test_view.py @@ -176,16 +176,16 @@ def setUp(self): self.groupTwo.policies.add(self.policyTwo) self.groupTwo.save() - self.adminRole = Role(**admin_def_role_config, tenant=self.tenant) - self.adminRole.save() - - self.platformAdminRole = Role(**platform_admin_def_role_config, tenant=self.tenant) - self.platformAdminRole.save() - self.public_tenant = Tenant.objects.get(tenant_name="public") self.sysPubRole = Role(**sys_pub_role_config, tenant=self.public_tenant) self.sysPubRole.save() + self.adminRole = Role(**admin_def_role_config, tenant=self.public_tenant) + self.adminRole.save() + + self.platformAdminRole = Role(**platform_admin_def_role_config, tenant=self.public_tenant) + self.platformAdminRole.save() + self.sysRole = Role(**sys_role_config, tenant=self.public_tenant) self.sysRole.save() @@ -1684,6 +1684,7 @@ def test_update_admin_default_role(self): test_data = {"name": "role_name", "display_name": "role_display", "access": access_data} response = client.put(url, test_data, format="json", **self.headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data["errors"][0]["detail"], "System roles may not be updated.") def test_delete_default_role(self): """Test that default roles are protected from deletion"""