-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for users, groups and link them to operational units.
- Loading branch information
Showing
10 changed files
with
292 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,13 +4,15 @@ | |
from datetime import datetime | ||
from typing import Any, Dict | ||
|
||
from polyfactory import PostGenerated | ||
from polyfactory import PostGenerated, Use | ||
from polyfactory.factories.pydantic_factory import ModelFactory | ||
from polyfactory.pytest_plugin import register_fixture | ||
|
||
from qualicharge.conf import settings | ||
from qualicharge.factories import FrenchDataclassFactory, TimestampedSQLModelFactory | ||
|
||
from .models import IDToken | ||
from .schemas import Group, User | ||
|
||
|
||
def set_token_exp(name: str, values: Dict[str, int], *args: Any, **kwargs: Any) -> int: | ||
|
@@ -29,3 +31,20 @@ class IDTokenFactory(ModelFactory[IDToken]): | |
iat = int(datetime.now().timestamp()) | ||
scope = "email profile" | ||
email = "[email protected]" | ||
|
||
|
||
class UserFactory(TimestampedSQLModelFactory[User]): | ||
"""User schema factory.""" | ||
|
||
username = Use( | ||
lambda: FrenchDataclassFactory.__faker__.simple_profile().get("username") | ||
) | ||
email = Use(FrenchDataclassFactory.__faker__.ascii_company_email) | ||
first_name = Use(FrenchDataclassFactory.__faker__.first_name) | ||
last_name = Use(FrenchDataclassFactory.__faker__.last_name) | ||
|
||
|
||
class GroupFactory(TimestampedSQLModelFactory[Group]): | ||
"""Group schema factory.""" | ||
|
||
name = Use(FrenchDataclassFactory.__faker__.company) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""QualiCharge authentication schemas.""" | ||
|
||
from typing import TYPE_CHECKING, Optional | ||
from uuid import UUID, uuid4 | ||
|
||
from pydantic import EmailStr | ||
from sqlalchemy.types import String | ||
from sqlmodel import Field, Relationship, SQLModel | ||
|
||
from qualicharge.schemas import BaseTimestampedSQLModel | ||
|
||
if TYPE_CHECKING: | ||
from qualicharge.schemas.core import OperationalUnit | ||
|
||
|
||
# -- Many-to-many relationships | ||
class UserGroup(SQLModel, table=True): | ||
"""M2M User-Group intermediate table.""" | ||
|
||
user_id: UUID = Field(foreign_key="user.id", primary_key=True) | ||
group_id: UUID = Field(foreign_key="group.id", primary_key=True) | ||
|
||
|
||
class GroupOperationalUnit(SQLModel, table=True): | ||
"""M2M Group-OperationalUnit intermediate table.""" | ||
|
||
group_id: UUID = Field(foreign_key="group.id", primary_key=True) | ||
operational_unit_id: UUID = Field( | ||
foreign_key="operationalunit.id", primary_key=True | ||
) | ||
|
||
|
||
# -- Core schemas | ||
class User(BaseTimestampedSQLModel, table=True): | ||
"""QualiCharge User.""" | ||
|
||
id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True) | ||
username: str = Field(unique=True, max_length=150) | ||
email: EmailStr = Field(sa_type=String) | ||
first_name: Optional[str] = Field(max_length=150) | ||
last_name: Optional[str] = Field(max_length=150) | ||
is_active: bool = False | ||
is_staff: bool = False | ||
is_superuser: bool = False | ||
|
||
# Relationships | ||
groups: list["Group"] = Relationship(back_populates="users", link_model=UserGroup) | ||
|
||
|
||
class Group(BaseTimestampedSQLModel, table=True): | ||
"""QualiCharge Group.""" | ||
|
||
id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True) | ||
name: str = Field(unique=True, max_length=150) | ||
|
||
# Relationships | ||
users: list["User"] = Relationship(back_populates="groups", link_model=UserGroup) | ||
operational_units: list["OperationalUnit"] = Relationship( | ||
back_populates="groups", link_model=GroupOperationalUnit | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""QualiCharge auth.utils module.""" | ||
|
||
from typing import Sequence, cast | ||
|
||
from sqlalchemy import Column as SAColumn | ||
from sqlalchemy.sql.roles import JoinTargetRole | ||
from sqlmodel import Session as SMSession | ||
from sqlmodel import select | ||
|
||
from qualicharge.schemas.core import OperationalUnit | ||
|
||
from .schemas import Group, User | ||
|
||
|
||
def get_user_operational_units(user: User, session: SMSession) -> Sequence[str]: | ||
"""Get user related operational unit codes.""" | ||
return session.exec( | ||
select(OperationalUnit.code) | ||
.join(cast(JoinTargetRole, OperationalUnit.groups)) | ||
.filter(cast(SAColumn, Group.id).in_(group.id for group in user.groups)) | ||
).all() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
88 changes: 88 additions & 0 deletions
88
src/api/qualicharge/migrations/versions/7568f5ff860e_add_user_and_group_schemas.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""add user and group schemas | ||
Revision ID: 7568f5ff860e | ||
Revises: fda96abb970d | ||
Create Date: 2024-05-20 14:20:28.454872 | ||
""" | ||
|
||
from typing import Sequence, Union | ||
|
||
from alembic import op | ||
import sqlalchemy as sa | ||
import sqlmodel | ||
from sqlalchemy.dialects import postgresql | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "7568f5ff860e" | ||
down_revision: Union[str, None] = "fda96abb970d" | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.create_table( | ||
"group", | ||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), | ||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), | ||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), | ||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.CheckConstraint("created_at <= updated_at", name="pre-creation-update"), | ||
sa.PrimaryKeyConstraint("id"), | ||
sa.UniqueConstraint("name"), | ||
) | ||
op.create_table( | ||
"user", | ||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), | ||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), | ||
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), | ||
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.Column("email", sa.String(), nullable=False), | ||
sa.Column("first_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True), | ||
sa.Column("last_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True), | ||
sa.Column("is_active", sa.Boolean(), nullable=False), | ||
sa.Column("is_staff", sa.Boolean(), nullable=False), | ||
sa.Column("is_superuser", sa.Boolean(), nullable=False), | ||
sa.CheckConstraint("created_at <= updated_at", name="pre-creation-update"), | ||
sa.PrimaryKeyConstraint("id"), | ||
sa.UniqueConstraint("username"), | ||
) | ||
op.create_table( | ||
"groupoperationalunit", | ||
sa.Column("group_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), | ||
sa.Column("operational_unit_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), | ||
sa.ForeignKeyConstraint( | ||
["group_id"], | ||
["group.id"], | ||
), | ||
sa.ForeignKeyConstraint( | ||
["operational_unit_id"], | ||
["operationalunit.id"], | ||
), | ||
sa.PrimaryKeyConstraint("group_id", "operational_unit_id"), | ||
) | ||
op.create_table( | ||
"usergroup", | ||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), | ||
sa.Column("group_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), | ||
sa.ForeignKeyConstraint( | ||
["group_id"], | ||
["group.id"], | ||
), | ||
sa.ForeignKeyConstraint( | ||
["user_id"], | ||
["user.id"], | ||
), | ||
sa.PrimaryKeyConstraint("user_id", "group_id"), | ||
) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_table("usergroup") | ||
op.drop_table("groupoperationalunit") | ||
op.drop_table("user") | ||
op.drop_table("group") | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""QualiCharge auth module tests.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Tests for qualicharge.auth.schemas module.""" | ||
|
||
from sqlmodel import select | ||
|
||
from qualicharge.auth.factories import GroupFactory, UserFactory | ||
from qualicharge.auth.schemas import GroupOperationalUnit, UserGroup | ||
from qualicharge.schemas.core import OperationalUnit | ||
|
||
|
||
def test_create_user_group_operational_units(db_session): | ||
"""Test the user to operational unit relationship.""" | ||
UserFactory.__session__ = db_session | ||
GroupFactory.__session__ = db_session | ||
|
||
# Create users and groups | ||
user_one, user_two = UserFactory.create_batch_sync(2) | ||
group_one, group_two = GroupFactory.create_batch_sync(2) | ||
db_session.add(UserGroup(user_id=user_one.id, group_id=group_one.id)) | ||
db_session.add(UserGroup(user_id=user_two.id, group_id=group_two.id)) | ||
|
||
assert group_one.users == [ | ||
user_one, | ||
] | ||
assert group_two.users == [ | ||
user_two, | ||
] | ||
|
||
# Link group to an operational unit | ||
code = "FRS63" | ||
operational_unit = db_session.exec( | ||
select(OperationalUnit).where(OperationalUnit.code == code) | ||
).one() | ||
db_session.add( | ||
GroupOperationalUnit( | ||
group_id=group_one.id, operational_unit_id=operational_unit.id | ||
) | ||
) | ||
|
||
assert user_one.groups[0].operational_units[0].id == operational_unit.id |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Tests for qualicharge.auth.utils schemas module.""" | ||
|
||
from random import sample | ||
from typing import cast | ||
|
||
from sqlalchemy import Column as SAColumn | ||
from sqlmodel import select | ||
|
||
from qualicharge.auth.factories import GroupFactory, UserFactory | ||
from qualicharge.auth.schemas import GroupOperationalUnit, UserGroup | ||
from qualicharge.auth.utils import get_user_operational_units | ||
from qualicharge.fixtures.operational_units import data as operational_unit_data | ||
from qualicharge.schemas.core import OperationalUnit | ||
|
||
|
||
def test_user_get_operational_units(db_session): | ||
"""Test the User get_operational_units utility.""" | ||
UserFactory.__session__ = db_session | ||
GroupFactory.__session__ = db_session | ||
|
||
# Create user, groups and link them (with operational units) | ||
user = UserFactory.create_sync() | ||
n_groups = 8 | ||
groups = GroupFactory.create_batch_sync(n_groups) | ||
user_n_groups = 2 | ||
user_groups = sample(groups, user_n_groups) | ||
operational_unit_codes = [ | ||
operational_unit.code | ||
for operational_unit in sample(operational_unit_data, n_groups) | ||
] | ||
operational_units = db_session.exec( | ||
select(OperationalUnit).where( | ||
cast(SAColumn, OperationalUnit.code).in_(operational_unit_codes) | ||
) | ||
) | ||
db_session.add_all( | ||
UserGroup(user_id=user.id, group_id=group.id) for group in user_groups | ||
) | ||
db_session.add_all( | ||
GroupOperationalUnit(group_id=group.id, operational_unit_id=operational_unit.id) | ||
for group, operational_unit in zip(groups, operational_units) | ||
) | ||
|
||
# Get operational unit codes | ||
user_operational_unit_codes = get_user_operational_units(user, db_session) | ||
assert len(user_operational_unit_codes) == user_n_groups | ||
assert set(user_operational_unit_codes) == { | ||
operational_unit.code | ||
for group in user.groups | ||
for operational_unit in group.operational_units | ||
} |