Skip to content

Commit

Permalink
Update service account auth to not require rbac enabled org
Browse files Browse the repository at this point in the history
  • Loading branch information
matiasb committed Dec 12, 2024
1 parent b8dc7af commit 84806d4
Show file tree
Hide file tree
Showing 15 changed files with 124 additions and 86 deletions.
2 changes: 1 addition & 1 deletion engine/apps/api/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def user_is_authorized(user: "User", required_permissions: LegacyAccessControlCo
`required_permissions` - A list of permissions that a user must have to be considered authorized
"""
organization = user.organization
if organization.is_rbac_permissions_enabled:
if organization.is_rbac_permissions_enabled or user.is_service_account:
user_permissions = [u["action"] for u in user.permissions]
required_permission_values = get_required_permission_values(organization, required_permissions)
return all(permission in user_permissions for permission in required_permission_values)
Expand Down
4 changes: 0 additions & 4 deletions engine/apps/auth_token/models/service_account_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ def organization(self):

@classmethod
def validate_token(cls, organization, token):
# require RBAC enabled to allow service account auth
if not organization.is_rbac_permissions_enabled:
raise InvalidToken

# Grafana API request: get permissions and confirm token is valid
permissions = get_service_account_token_permissions(organization, token)
if not permissions:
Expand Down
41 changes: 6 additions & 35 deletions engine/apps/auth_token/tests/test_grafana_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from apps.auth_token.auth import X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication
from apps.auth_token.models import ServiceAccountToken
from apps.auth_token.tests.helpers import setup_service_account_api_mocks
from apps.user_management.models import Organization, ServiceAccountUser
from apps.user_management.models import Organization
from common.constants.plugin_ids import PluginID
from settings.base import CLOUD_LICENSE_NAME, OPEN_SOURCE_LICENSE_NAME, SELF_HOSTED_SETTINGS

Expand Down Expand Up @@ -115,31 +115,10 @@ def test_grafana_authentication_invalid_grafana_url():
assert exc.value.detail == "Organization not found."


@pytest.mark.django_db
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_rbac_disabled_fails(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if organization.is_rbac_permissions_enabled:
return

token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz"
headers = {
"HTTP_AUTHORIZATION": token,
"HTTP_X_GRAFANA_URL": organization.grafana_url,
}
request = APIRequestFactory().get("/", **headers)

with pytest.raises(exceptions.AuthenticationFailed) as exc:
GrafanaServiceAccountAuthentication().authenticate(request)
assert exc.value.detail == "Invalid token."


@pytest.mark.django_db
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_permissions_call_fails(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return

token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz"
headers = {
Expand Down Expand Up @@ -170,8 +149,6 @@ def test_grafana_authentication_existing_token(
make_organization, make_service_account_for_organization, make_token_for_service_account
):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
service_account = make_service_account_for_organization(organization)
token_string = "glsa_the-token"
token = make_token_for_service_account(service_account, token_string)
Expand All @@ -187,7 +164,7 @@ def test_grafana_authentication_existing_token(

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
assert user.service_account == service_account
assert user.public_primary_key == service_account.public_primary_key
assert user.username == service_account.username
Expand All @@ -206,8 +183,6 @@ def test_grafana_authentication_existing_token(
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_token_created(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
token_string = "glsa_the-token"

headers = {
Expand All @@ -223,7 +198,7 @@ def test_grafana_authentication_token_created(make_organization):

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
service_account = user.service_account
assert service_account.organization == organization
assert user.public_primary_key == service_account.public_primary_key
Expand All @@ -248,8 +223,6 @@ def test_grafana_authentication_token_created(make_organization):
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_token_created_older_grafana(make_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
token_string = "glsa_the-token"

headers = {
Expand All @@ -265,7 +238,7 @@ def test_grafana_authentication_token_created_older_grafana(make_organization):

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
service_account = user.service_account
assert service_account.organization == organization
# use fallback data
Expand All @@ -278,8 +251,6 @@ def test_grafana_authentication_token_created_older_grafana(make_organization):
@httpretty.activate(verbose=True, allow_net_connect=False)
def test_grafana_authentication_token_reuse_service_account(make_organization, make_service_account_for_organization):
organization = make_organization(grafana_url="http://grafana.test")
if not organization.is_rbac_permissions_enabled:
return
service_account = make_service_account_for_organization(organization)
token_string = "glsa_the-token"

Expand All @@ -299,7 +270,7 @@ def test_grafana_authentication_token_reuse_service_account(make_organization, m

user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request)

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
assert user.service_account == service_account
assert auth_token.service_account == service_account

Expand Down Expand Up @@ -335,7 +306,7 @@ def sync_org():

mock_setup_org.assert_called_once()

assert isinstance(user, ServiceAccountUser)
assert user.is_service_account
service_account = user.service_account
# organization is created
organization = Organization.objects.filter(grafana_url=grafana_url).get()
Expand Down
6 changes: 5 additions & 1 deletion engine/apps/grafana_plugin/serializers/sync_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class SyncOnCallSettingsSerializer(serializers.Serializer):
labels_enabled = serializers.BooleanField()
irm_enabled = serializers.BooleanField(default=False)

def validate_grafana_url(self, value):
# remove trailing slash for URL consistency
return value.rstrip("/")

def create(self, validated_data):
return SyncSettings(**validated_data)

Expand All @@ -81,7 +85,7 @@ def to_representation(self, instance):


class SyncDataSerializer(serializers.Serializer):
users = serializers.ListField(child=SyncUserSerializer())
users = serializers.ListField(child=SyncUserSerializer(), allow_null=True, allow_empty=True)
teams = serializers.ListField(child=SyncTeamSerializer(), allow_null=True, allow_empty=True)
team_members = TeamMemberMappingField()
settings = SyncOnCallSettingsSerializer()
Expand Down
65 changes: 64 additions & 1 deletion engine/apps/grafana_plugin/tests/test_sync_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rest_framework.test import APIClient

from apps.api.permissions import LegacyAccessControlRole
from apps.grafana_plugin.serializers.sync_data import SyncTeamSerializer
from apps.grafana_plugin.serializers.sync_data import SyncOnCallSettingsSerializer, SyncTeamSerializer
from apps.grafana_plugin.sync_data import SyncData, SyncSettings, SyncUser
from apps.grafana_plugin.tasks.sync_v2 import start_sync_organizations_v2, sync_organizations_v2
from common.constants.plugin_ids import PluginID
Expand Down Expand Up @@ -197,6 +197,47 @@ def test_sync_v2_irm_enabled(
assert organization.is_grafana_irm_enabled == expected


@patch("apps.grafana_plugin.helpers.client.GrafanaAPIClient.check_token", return_value=(None, {"connected": True}))
@pytest.mark.django_db
def test_sync_v2_none_values(
# mock this out so that we're not making a real network call, the sync v2 endpoint ends up calling
# user_management.sync._sync_organization which calls GrafanaApiClient.check_token
_mock_grafana_api_client_check_token,
make_organization_and_user_with_plugin_token,
make_user_auth_headers,
settings,
):
settings.LICENSE = settings.CLOUD_LICENSE_NAME
organization, _, token = make_organization_and_user_with_plugin_token()

client = APIClient()
headers = make_user_auth_headers(None, token, organization=organization)
url = reverse("grafana-plugin:sync-v2")

data = SyncData(
users=None,
teams=None,
team_members={},
settings=SyncSettings(
stack_id=organization.stack_id,
org_id=organization.org_id,
license=settings.CLOUD_LICENSE_NAME,
oncall_api_url="http://localhost",
oncall_token="",
grafana_url="http://localhost",
grafana_token="fake_token",
rbac_enabled=False,
incident_enabled=False,
incident_backend_url="",
labels_enabled=False,
irm_enabled=False,
),
)

response = client.post(url, format="json", data=asdict(data), **headers)
assert response.status_code == status.HTTP_200_OK


@pytest.mark.parametrize(
"test_team, validation_pass",
[
Expand All @@ -218,6 +259,28 @@ def test_sync_team_serialization(test_team, validation_pass):
assert (validation_error is None) == validation_pass


@pytest.mark.django_db
def test_sync_grafana_url_serialization():
data = {
"stack_id": 123,
"org_id": 321,
"license": "OSS",
"oncall_api_url": "http://localhost",
"oncall_token": "",
"grafana_url": "http://localhost/",
"grafana_token": "fake_token",
"rbac_enabled": False,
"incident_enabled": False,
"incident_backend_url": "",
"labels_enabled": False,
"irm_enabled": False,
}
serializer = SyncOnCallSettingsSerializer(data=data)
serializer.is_valid(raise_exception=True)
cleaned_data = serializer.save()
assert cleaned_data.grafana_url == "http://localhost"


@pytest.mark.django_db
def test_sync_batch_tasks(make_organization, settings):
settings.SYNC_V2_MAX_TASKS = 2
Expand Down
5 changes: 2 additions & 3 deletions engine/apps/public_api/serializers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from apps.alerts.models import AlertReceiveChannel
from apps.base.messaging import get_messaging_backends
from apps.integrations.legacy_prefix import has_legacy_prefix, remove_legacy_prefix
from apps.user_management.models import ServiceAccountUser
from common.api_helpers.custom_fields import TeamPrimaryKeyRelatedField
from common.api_helpers.exceptions import BadRequest
from common.api_helpers.mixins import PHONE_CALL, SLACK, SMS, TELEGRAM, WEB, EagerLoadingMixin
Expand Down Expand Up @@ -129,8 +128,8 @@ def create(self, validated_data):
try:
instance = AlertReceiveChannel.create(
**validated_data,
author=user if not isinstance(user, ServiceAccountUser) else None,
service_account=user.service_account if isinstance(user, ServiceAccountUser) else None,
author=user if not user.is_service_account else None,
service_account=user.service_account if user.is_service_account else None,
organization=organization,
)
except AlertReceiveChannel.DuplicateDirectPagingError:
Expand Down
3 changes: 1 addition & 2 deletions engine/apps/public_api/serializers/resolution_notes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from rest_framework import serializers

from apps.alerts.models import AlertGroup, ResolutionNote
from apps.user_management.models import ServiceAccountUser
from common.api_helpers.custom_fields import OrganizationFilteredPrimaryKeyRelatedField, UserIdField
from common.api_helpers.exceptions import BadRequest
from common.api_helpers.mixins import EagerLoadingMixin
Expand Down Expand Up @@ -36,7 +35,7 @@ class Meta:

def create(self, validated_data):
user = self.context["request"].user
if not isinstance(user, ServiceAccountUser) and user.pk:
if not user.is_service_account and user.pk:
validated_data["author"] = user
validated_data["source"] = ResolutionNote.Source.WEB
return super().create(validated_data)
Expand Down
1 change: 1 addition & 0 deletions engine/apps/public_api/serializers/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def validate_preset(self, preset):
raise serializers.ValidationError(PRESET_VALIDATION_MESSAGE)

def validate_user(self, user):
# user may also be a string when handling requests from the deprecated custom action API
if isinstance(user, ServiceAccountUser):
return None
return user
Expand Down
9 changes: 3 additions & 6 deletions engine/apps/public_api/tests/test_integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,9 @@ def test_create_integration_via_service_account(
HTTP_AUTHORIZATION=f"{token_string}",
HTTP_X_GRAFANA_URL=organization.grafana_url,
)
if not organization.is_rbac_permissions_enabled:
assert response.status_code == status.HTTP_403_FORBIDDEN
else:
assert response.status_code == status.HTTP_201_CREATED
integration = AlertReceiveChannel.objects.get(public_primary_key=response.data["id"])
assert integration.service_account == service_account
assert response.status_code == status.HTTP_201_CREATED
integration = AlertReceiveChannel.objects.get(public_primary_key=response.data["id"])
assert integration.service_account == service_account


@pytest.mark.django_db
Expand Down
15 changes: 6 additions & 9 deletions engine/apps/public_api/tests/test_resolution_notes.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,12 @@ def test_create_resolution_note_via_service_account(
HTTP_AUTHORIZATION=f"{token_string}",
HTTP_X_GRAFANA_URL=organization.grafana_url,
)
if not organization.is_rbac_permissions_enabled:
assert response.status_code == status.HTTP_403_FORBIDDEN
else:
assert response.status_code == status.HTTP_201_CREATED
mock_send_update_resolution_note_signal.assert_called_once()
resolution_note = ResolutionNote.objects.get(public_primary_key=response.data["id"])
assert resolution_note.author is None
assert resolution_note.text == data["text"]
assert resolution_note.alert_group == alert_group
assert response.status_code == status.HTTP_201_CREATED
mock_send_update_resolution_note_signal.assert_called_once()
resolution_note = ResolutionNote.objects.get(public_primary_key=response.data["id"])
assert resolution_note.author is None
assert resolution_note.text == data["text"]
assert resolution_note.alert_group == alert_group


@pytest.mark.django_db
Expand Down
11 changes: 4 additions & 7 deletions engine/apps/public_api/tests/test_webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,10 @@ def test_create_webhook_via_service_account(
HTTP_AUTHORIZATION=f"{token_string}",
HTTP_X_GRAFANA_URL=organization.grafana_url,
)
if not organization.is_rbac_permissions_enabled:
assert response.status_code == status.HTTP_403_FORBIDDEN
else:
assert response.status_code == status.HTTP_201_CREATED
webhook = Webhook.objects.get(public_primary_key=response.data["id"])
expected_result = _get_expected_result(webhook)
assert response.data == expected_result
assert response.status_code == status.HTTP_201_CREATED
webhook = Webhook.objects.get(public_primary_key=response.data["id"])
expected_result = _get_expected_result(webhook)
assert response.data == expected_result


@pytest.mark.django_db
Expand Down
Loading

0 comments on commit 84806d4

Please sign in to comment.