Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: ad groups update of the user model #103

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 154 additions & 1 deletion helusers/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import pytest
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group

from helusers.jwt import JWT
from helusers.models import OIDCBackChannelLogoutEvent
from helusers.models import ADGroup, ADGroupMapping, OIDCBackChannelLogoutEvent

from .conftest import encoded_jwt_factory, ISSUER1

user_model = get_user_model()


@pytest.mark.django_db
class TestOIDCBackChannelLogoutEvent:
Expand Down Expand Up @@ -62,3 +66,152 @@ def test_receiving_the_same_logout_token_more_than_once_has_no_effect(self):
OIDCBackChannelLogoutEvent.objects.logout_token_received(logout_token)

assert OIDCBackChannelLogoutEvent.objects.count() == 1


@pytest.mark.django_db
class TestUserAdGroups:
ALL_AD_GROUPS_MAPPING = (
("ad_group_1", "group_1"),
("ad_group_2", "group_2"),
("ad_group_3", "group_3"),
)
ALL_AD_GROUP_NAMES = ("ad_group_1", "ad_group_2", "ad_group_3")
ALL_GROUP_NAMES = ("group_1", "group_2", "group_3")

@pytest.mark.parametrize(
"ad_group_mapping,old_ad_groups_names,old_groups_names,new_ad_groups_names,new_groups_names",
[
# Nothing changes
pytest.param(
ALL_AD_GROUPS_MAPPING,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
id="nothing_changes",
),
# If not mapped, not added
pytest.param(
(("ad_group_1", "group_1"),),
("ad_group_1",),
("group_1",),
ALL_AD_GROUP_NAMES,
("group_1",),
id="not_mapped_not_added",
),
# New ones are added
pytest.param(
ALL_AD_GROUPS_MAPPING,
("ad_group_1",),
("group_1",),
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
id="new_added",
),
# Old ones are removed
pytest.param(
ALL_AD_GROUPS_MAPPING,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
("ad_group_1",),
("group_1",),
id="old_removed",
),
# Mapped twice, given once
pytest.param(
(
("ad_group_1", "group_1"),
("ad_group_1_1", "group_1"),
("ad_group_2", "group_2"),
("ad_group_2_2", "group_2"),
("ad_group_3", "group_3"),
),
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
id="mapped_twice_given_once",
),
# Mapped twice, given twice & 1 removed.
pytest.param(
(
("ad_group_1", "group_1"),
("ad_group_1_1", "group_1"),
("ad_group_2", "group_2"),
("ad_group_2_2", "group_2"),
("ad_group_3", "group_3"),
),
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
(
"ad_group_1",
"ad_group_1_1",
"ad_group_2",
),
("group_1", "group_2"),
id="mapped_twice_given_twice",
),
# All mapped, empty list given: All should be removed.
pytest.param(
ALL_AD_GROUPS_MAPPING,
ALL_AD_GROUP_NAMES,
ALL_GROUP_NAMES,
[],
[],
id="all_removed",
),
],
)
def test_update_ad_groups(
self,
ad_group_mapping,
old_ad_groups_names,
old_groups_names,
new_ad_groups_names,
new_groups_names,
):
# Setup ad groups mapping
ADGroupMapping.objects.bulk_create(
[
ADGroupMapping(
ad_group=ADGroup.objects.get_or_create(
name=ad_group_name, display_name=ad_group_name
)[0],
group=Group.objects.get_or_create(name=group_name)[0],
)
for ad_group_name, group_name in ad_group_mapping
]
)

# Setup existing AD-groups
old_ad_groups = [
ADGroup.objects.get_or_create(name=name, display_name=name)[0]
for name in old_ad_groups_names
]

# Setup existing groups
old_groups = [
Group.objects.get_or_create(name=name)[0] for name in old_groups_names
]

# Setup a user
user = user_model.objects.create(username="testguy")
user.ad_groups.set(old_ad_groups)
user.groups.set(old_groups)
user.save()

# Expect that the ad groups and groups are persisted to the user instance
assert ADGroupMapping.objects.count() == len(ad_group_mapping)
assert user.ad_groups.count() == len(old_ad_groups_names)
assert user.groups.count() == len(old_groups_names)

# When user update_ad_groups is called
user.update_ad_groups(ad_group_names=new_ad_groups_names)

# Then user has a new set of groups
assert sorted([ad_group.name for ad_group in user.ad_groups.all()]) == list(
new_ad_groups_names
)
assert sorted([group.name for group in user.groups.all()]) == list(
new_groups_names
)