Skip to content

Commit

Permalink
feat(organization-invites): handle duplicate invites (#26404)
Browse files Browse the repository at this point in the history
  • Loading branch information
raquelmsmith authored Nov 26, 2024
1 parent 2bd9f57 commit 0f78df2
Show file tree
Hide file tree
Showing 2 changed files with 248 additions and 3 deletions.
99 changes: 99 additions & 0 deletions posthog/api/organization_invite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from datetime import datetime, timedelta
from typing import Any, Optional, cast
from uuid import UUID

import posthoganalytics
from django.db.models import QuerySet
from rest_framework import (
exceptions,
mixins,
Expand All @@ -15,6 +18,7 @@
from posthog.api.routing import TeamAndOrgViewSetMixin
from posthog.api.shared import UserBasicSerializer
from posthog.api.utils import action
from posthog.constants import INVITE_DAYS_VALIDITY
from posthog.email import is_email_available
from posthog.event_usage import report_bulk_invited, report_team_member_invited
from posthog.models import OrganizationInvite, OrganizationMembership
Expand All @@ -24,9 +28,85 @@
from posthog.tasks.email import send_invite


class OrganizationInviteManager:
@staticmethod
def combine_invites(
organization_id: UUID | str, validated_data: dict[str, Any], combine_pending_invites: bool = True
) -> dict[str, Any]:
"""Combines multiple pending invites for the same email address."""
if not combine_pending_invites:
return validated_data

existing_invites = OrganizationInviteManager._get_invites_for_user_org(
organization_id=organization_id, target_email=validated_data["target_email"]
)

if not existing_invites.exists():
return validated_data

validated_data["level"] = OrganizationInviteManager._get_highest_level(
existing_invites=existing_invites,
new_level=validated_data.get("level", OrganizationMembership.Level.MEMBER),
)

validated_data["private_project_access"] = OrganizationInviteManager._combine_project_access(
existing_invites=existing_invites, new_access=validated_data.get("private_project_access", [])
)

return validated_data

@staticmethod
def _get_invites_for_user_org(
organization_id: UUID | str, target_email: str, include_expired: bool = False
) -> QuerySet:
filters: dict[str, Any] = {
"organization_id": organization_id,
"target_email": target_email,
}

if not include_expired:
filters["created_at__gt"] = datetime.now() - timedelta(days=INVITE_DAYS_VALIDITY)

return OrganizationInvite.objects.filter(**filters).order_by("-created_at")

@staticmethod
def _get_highest_level(existing_invites: QuerySet, new_level: int) -> int:
levels = [invite.level for invite in existing_invites]
levels.append(new_level)
return max(levels)

@staticmethod
def _combine_project_access(existing_invites: QuerySet, new_access: list[dict]) -> list[dict]:
combined_access: dict[int, int] = {}

# Add new access first
for access in new_access:
combined_access[access["id"]] = access["level"]

# Combine with existing access, keeping highest levels
for invite in existing_invites:
if not invite.private_project_access:
continue

for access in invite.private_project_access:
project_id = access["id"]
if project_id not in combined_access or access["level"] > combined_access[project_id]:
combined_access[project_id] = access["level"]

return [{"id": project_id, "level": level} for project_id, level in combined_access.items()]

@staticmethod
def delete_existing_invites(organization_id: UUID | str, target_email: str) -> None:
"""Deletes all existing invites for a given email in an organization."""
OrganizationInviteManager._get_invites_for_user_org(
organization_id=organization_id, target_email=target_email, include_expired=True
).delete()


class OrganizationInviteSerializer(serializers.ModelSerializer):
created_by = UserBasicSerializer(read_only=True)
send_email = serializers.BooleanField(write_only=True, default=True)
combine_pending_invites = serializers.BooleanField(write_only=True, default=False)

class Meta:
model = OrganizationInvite
Expand All @@ -43,6 +123,7 @@ class Meta:
"message",
"private_project_access",
"send_email",
"combine_pending_invites",
]
read_only_fields = [
"id",
Expand Down Expand Up @@ -107,12 +188,30 @@ def create(self, validated_data: dict[str, Any], *args: Any, **kwargs: Any) -> O
user__email=validated_data["target_email"],
).exists():
raise exceptions.ValidationError("A user with this email address already belongs to the organization.")

combine_pending_invites = validated_data.pop("combine_pending_invites", False)
send_email = validated_data.pop("send_email", True)

# Handle invite combination if requested
if combine_pending_invites:
validated_data = OrganizationInviteManager.combine_invites(
organization_id=self.context["organization_id"],
validated_data=validated_data,
combine_pending_invites=True,
)

# Delete existing invites for this email
OrganizationInviteManager.delete_existing_invites(
organization_id=self.context["organization_id"], target_email=validated_data["target_email"]
)

# Create new invite
invite: OrganizationInvite = OrganizationInvite.objects.create(
organization_id=self.context["organization_id"],
created_by=self.context["request"].user,
**validated_data,
)

if is_email_available(with_absolute_urls=True) and send_email:
invite.emailing_attempt_made = True
send_invite(invite_id=invite.id)
Expand Down
152 changes: 149 additions & 3 deletions posthog/api/test/test_organization_invites.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import ANY, patch

from django.core import mail
from freezegun import freeze_time
from rest_framework import status

from ee.models.explicit_team_membership import ExplicitTeamMembership
Expand Down Expand Up @@ -156,18 +157,18 @@ def test_add_organization_invite_with_email_on_instance_but_send_email_prop_fals
# Assert invite email is not sent
self.assertEqual(len(mail.outbox), 0)

def test_can_create_invites_for_the_same_email_multiple_times(self):
def test_create_invites_for_the_same_email_multiple_times_deletes_older_invites(self):
email = "[email protected]"
count = OrganizationInvite.objects.count()

for _ in range(0, 2):
for _ in range(0, 3):
response = self.client.post("/api/organizations/@current/invites/", {"target_email": email})
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
obj = OrganizationInvite.objects.get(id=response.json()["id"])
self.assertEqual(obj.target_email, email)
self.assertEqual(obj.created_by, self.user)

self.assertEqual(OrganizationInvite.objects.count(), count + 2)
self.assertEqual(OrganizationInvite.objects.count(), count + 1)

def test_can_specify_membership_level_in_invite(self):
email = "[email protected]"
Expand Down Expand Up @@ -508,3 +509,148 @@ def test_delete_organization_invite_if_plain_member(self):
self.assertEqual(response.content, b"") # Empty response
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
self.assertFalse(OrganizationInvite.objects.exists())

# Combine pending invites

def test_combine_pending_invites_combines_levels_and_project_access(self):
email = "[email protected]"
private_team_1 = Team.objects.create(organization=self.organization, name="Private Team 1", access_control=True)
private_team_2 = Team.objects.create(organization=self.organization, name="Private Team 2", access_control=True)

ExplicitTeamMembership.objects.create(
team=private_team_1,
parent_membership=self.organization_membership,
level=ExplicitTeamMembership.Level.ADMIN,
)
ExplicitTeamMembership.objects.create(
team=private_team_2,
parent_membership=self.organization_membership,
level=ExplicitTeamMembership.Level.ADMIN,
)

# Create first invite with member access to team 1
first_invite = self.client.post(
"/api/organizations/@current/invites/",
{
"target_email": email,
"level": OrganizationMembership.Level.MEMBER,
"private_project_access": [{"id": private_team_1.id, "level": ExplicitTeamMembership.Level.MEMBER}],
},
).json()

# Create second invite with admin access to team 2
second_invite = self.client.post(
"/api/organizations/@current/invites/",
{
"target_email": email,
"level": OrganizationMembership.Level.ADMIN,
"private_project_access": [{"id": private_team_2.id, "level": ExplicitTeamMembership.Level.ADMIN}],
},
).json()

# Create third invite combining previous invites
response = self.client.post(
"/api/organizations/@current/invites/",
{
"target_email": email,
"level": OrganizationMembership.Level.MEMBER,
"private_project_access": [{"id": private_team_1.id, "level": ExplicitTeamMembership.Level.ADMIN}],
"combine_pending_invites": True,
},
)

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
combined_invite = response.json()

# Check that previous invites are deleted
self.assertFalse(OrganizationInvite.objects.filter(id=first_invite["id"]).exists())
self.assertFalse(OrganizationInvite.objects.filter(id=second_invite["id"]).exists())

# Check that the new invite has the highest level (ADMIN)
self.assertEqual(combined_invite["level"], OrganizationMembership.Level.ADMIN)

# Check that private project access is combined with highest levels
expected_access = [
{"id": private_team_1.id, "level": ExplicitTeamMembership.Level.ADMIN},
{"id": private_team_2.id, "level": ExplicitTeamMembership.Level.ADMIN},
]
self.assertEqual(len(combined_invite["private_project_access"]), 2)
for access in expected_access:
self.assertIn(access, combined_invite["private_project_access"])

def test_combine_pending_invites_with_no_existing_invites(self):
email = "[email protected]"
response = self.client.post(
"/api/organizations/@current/invites/",
{
"target_email": email,
"level": OrganizationMembership.Level.MEMBER,
"combine_pending_invites": True,
},
)

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
invite = response.json()
self.assertEqual(invite["level"], OrganizationMembership.Level.MEMBER)
self.assertEqual(invite["target_email"], email)
self.assertEqual(invite["private_project_access"], [])

@freeze_time("2024-01-10")
def test_combine_pending_invites_with_expired_invites(self):
email = "[email protected]"

# Create an expired invite
with freeze_time("2023-01-05"):
OrganizationInvite.objects.create(
organization=self.organization,
target_email=email,
level=OrganizationMembership.Level.ADMIN,
)

response = self.client.post(
"/api/organizations/@current/invites/",
{
"target_email": email,
"level": OrganizationMembership.Level.MEMBER,
"combine_pending_invites": True,
},
)

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
invite = response.json()

# Check that the new invite uses its own level, not the expired invite's level
self.assertEqual(invite["level"], OrganizationMembership.Level.MEMBER)
self.assertEqual(invite["target_email"], email)
self.assertEqual(invite["private_project_access"], [])

def test_combine_pending_invites_false_expires_existing_invites(self):
email = "[email protected]"

# Create first invite
first_invite = self.client.post(
"/api/organizations/@current/invites/",
{
"target_email": email,
"level": OrganizationMembership.Level.ADMIN,
},
).json()

# Create second invite with combine_pending_invites=False
response = self.client.post(
"/api/organizations/@current/invites/",
{
"target_email": email,
"level": OrganizationMembership.Level.MEMBER,
"combine_pending_invites": False,
},
)

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_invite = response.json()

# Check that previous invite is deleted
self.assertFalse(OrganizationInvite.objects.filter(id=first_invite["id"]).exists())

# Check that new invite uses its own level
self.assertEqual(new_invite["level"], OrganizationMembership.Level.MEMBER)

0 comments on commit 0f78df2

Please sign in to comment.