Skip to content

Commit

Permalink
✨(auth) add User and Group schemas
Browse files Browse the repository at this point in the history
Add support for users, groups and link them to operational units.
  • Loading branch information
jmaupetit committed May 20, 2024
1 parent fdecc3a commit 2f32ff9
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to

- Implement `OperationalUnit` schema
- Link `OperationalUnit` to `Station` using AFIREV prefixes
- Implement `User` and `Group` schemas

## [0.5.0] - 2024-05-15

Expand Down
21 changes: 20 additions & 1 deletion src/api/qualicharge/auth/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from datetime import datetime
from typing import Any, Dict

from polyfactory import PostGenerated
from polyfactory import PostGenerated, Use
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture

from qualicharge.conf import settings
from qualicharge.factories import FrenchDataclassFactory, TimestampedSQLModelFactory

from .models import IDToken
from .schemas import Group, User


def set_token_exp(name: str, values: Dict[str, int], *args: Any, **kwargs: Any) -> int:
Expand All @@ -29,3 +31,20 @@ class IDTokenFactory(ModelFactory[IDToken]):
iat = int(datetime.now().timestamp())
scope = "email profile"
email = "[email protected]"


class UserFactory(TimestampedSQLModelFactory[User]):
"""User schema factory."""

username = Use(
lambda: FrenchDataclassFactory.__faker__.simple_profile().get("username")
)
email = Use(FrenchDataclassFactory.__faker__.ascii_company_email)
first_name = Use(FrenchDataclassFactory.__faker__.first_name)
last_name = Use(FrenchDataclassFactory.__faker__.last_name)


class GroupFactory(TimestampedSQLModelFactory[Group]):
"""Group schema factory."""

name = Use(FrenchDataclassFactory.__faker__.company)
60 changes: 60 additions & 0 deletions src/api/qualicharge/auth/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""QualiCharge authentication schemas."""

from typing import TYPE_CHECKING, Optional
from uuid import UUID, uuid4

from pydantic import EmailStr
from sqlalchemy.types import String
from sqlmodel import Field, Relationship, SQLModel

from qualicharge.schemas import BaseTimestampedSQLModel

if TYPE_CHECKING:
from qualicharge.schemas.core import OperationalUnit


# -- Many-to-many relationships
class UserGroup(SQLModel, table=True):
"""M2M User-Group intermediate table."""

user_id: UUID = Field(foreign_key="user.id", primary_key=True)
group_id: UUID = Field(foreign_key="group.id", primary_key=True)


class GroupOperationalUnit(SQLModel, table=True):
"""M2M Group-OperationalUnit intermediate table."""

group_id: UUID = Field(foreign_key="group.id", primary_key=True)
operational_unit_id: UUID = Field(
foreign_key="operationalunit.id", primary_key=True
)


# -- Core schemas
class User(BaseTimestampedSQLModel, table=True):
"""QualiCharge User."""

id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True)
username: str = Field(unique=True, max_length=150)
email: EmailStr = Field(sa_type=String)
first_name: Optional[str] = Field(max_length=150)
last_name: Optional[str] = Field(max_length=150)
is_active: bool = False
is_staff: bool = False
is_superuser: bool = False

# Relationships
groups: list["Group"] = Relationship(back_populates="users", link_model=UserGroup)


class Group(BaseTimestampedSQLModel, table=True):
"""QualiCharge Group."""

id: Optional[UUID] = Field(default_factory=lambda: uuid4().hex, primary_key=True)
name: str = Field(unique=True, max_length=150)

# Relationships
users: list["User"] = Relationship(back_populates="groups", link_model=UserGroup)
operational_units: list["OperationalUnit"] = Relationship(
back_populates="groups", link_model=GroupOperationalUnit
)
21 changes: 21 additions & 0 deletions src/api/qualicharge/auth/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""QualiCharge auth.utils module."""

from typing import Sequence, cast

from sqlalchemy import Column as SAColumn
from sqlalchemy.sql.roles import JoinTargetRole
from sqlmodel import Session as SMSession
from sqlmodel import select

from qualicharge.schemas.core import OperationalUnit

from .schemas import Group, User


def get_user_operational_units(user: User, session: SMSession) -> Sequence[str]:
"""Get user related operational unit codes."""
return session.exec(
select(OperationalUnit.code)
.join(cast(JoinTargetRole, OperationalUnit.groups))
.filter(cast(SAColumn, Group.id).in_(group.id for group in user.groups))
).all()
6 changes: 6 additions & 0 deletions src/api/qualicharge/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from qualicharge.conf import settings

# Nota bene: be sure to import all models that need to be migrated here
from qualicharge.auth.schemas import ( # noqa: F401
User,
UserGroup,
Group,
GroupOperationalUnit,
)
from qualicharge.schemas.core import ( # noqa: F401
Amenageur,
Enseigne,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""add user and group schemas
Revision ID: 7568f5ff860e
Revises: fda96abb970d
Create Date: 2024-05-20 14:20:28.454872
"""

from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "7568f5ff860e"
down_revision: Union[str, None] = "fda96abb970d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"group",
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.CheckConstraint("created_at <= updated_at", name="pre-creation-update"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
op.create_table(
"user",
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("email", sa.String(), nullable=False),
sa.Column("first_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("last_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("is_staff", sa.Boolean(), nullable=False),
sa.Column("is_superuser", sa.Boolean(), nullable=False),
sa.CheckConstraint("created_at <= updated_at", name="pre-creation-update"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("username"),
)
op.create_table(
"groupoperationalunit",
sa.Column("group_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("operational_unit_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["group_id"],
["group.id"],
),
sa.ForeignKeyConstraint(
["operational_unit_id"],
["operationalunit.id"],
),
sa.PrimaryKeyConstraint("group_id", "operational_unit_id"),
)
op.create_table(
"usergroup",
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("group_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.ForeignKeyConstraint(
["group_id"],
["group.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("user_id", "group_id"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("usergroup")
op.drop_table("groupoperationalunit")
op.drop_table("user")
op.drop_table("group")
# ### end Alembic commands ###
5 changes: 5 additions & 0 deletions src/api/qualicharge/schemas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sqlmodel import Session as SMSession
from sqlmodel.main import SQLModelConfig

from qualicharge.auth.schemas import Group, GroupOperationalUnit
from qualicharge.exceptions import ObjectDoesNotExist

from ..models.dynamic import SessionBase, StatusBase
Expand Down Expand Up @@ -187,7 +188,11 @@ class OperationalUnit(BaseTimestampedSQLModel, table=True):
name: str
type: OperationalUnitTypeEnum

# Relationships
stations: List["Station"] = Relationship(back_populates="operational_unit")
groups: List["Group"] = Relationship(
back_populates="operational_units", link_model=GroupOperationalUnit
)

def create_stations_fk(self, session: SMSession):
"""Create linked stations foreign keys."""
Expand Down
1 change: 1 addition & 0 deletions src/api/tests/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""QualiCharge auth module tests."""
39 changes: 39 additions & 0 deletions src/api/tests/auth/test_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Tests for qualicharge.auth.schemas module."""

from sqlmodel import select

from qualicharge.auth.factories import GroupFactory, UserFactory
from qualicharge.auth.schemas import GroupOperationalUnit, UserGroup
from qualicharge.schemas.core import OperationalUnit


def test_create_user_group_operational_units(db_session):
"""Test the user to operational unit relationship."""
UserFactory.__session__ = db_session
GroupFactory.__session__ = db_session

# Create users and groups
user_one, user_two = UserFactory.create_batch_sync(2)
group_one, group_two = GroupFactory.create_batch_sync(2)
db_session.add(UserGroup(user_id=user_one.id, group_id=group_one.id))
db_session.add(UserGroup(user_id=user_two.id, group_id=group_two.id))

assert group_one.users == [
user_one,
]
assert group_two.users == [
user_two,
]

# Link group to an operational unit
code = "FRS63"
operational_unit = db_session.exec(
select(OperationalUnit).where(OperationalUnit.code == code)
).one()
db_session.add(
GroupOperationalUnit(
group_id=group_one.id, operational_unit_id=operational_unit.id
)
)

assert user_one.groups[0].operational_units[0].id == operational_unit.id
51 changes: 51 additions & 0 deletions src/api/tests/auth/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests for qualicharge.auth.utils schemas module."""

from random import sample
from typing import cast

from sqlalchemy import Column as SAColumn
from sqlmodel import select

from qualicharge.auth.factories import GroupFactory, UserFactory
from qualicharge.auth.schemas import GroupOperationalUnit, UserGroup
from qualicharge.auth.utils import get_user_operational_units
from qualicharge.fixtures.operational_units import data as operational_unit_data
from qualicharge.schemas.core import OperationalUnit


def test_user_get_operational_units(db_session):
"""Test the User get_operational_units utility."""
UserFactory.__session__ = db_session
GroupFactory.__session__ = db_session

# Create user, groups and link them (with operational units)
user = UserFactory.create_sync()
n_groups = 8
groups = GroupFactory.create_batch_sync(n_groups)
user_n_groups = 2
user_groups = sample(groups, user_n_groups)
operational_unit_codes = [
operational_unit.code
for operational_unit in sample(operational_unit_data, n_groups)
]
operational_units = db_session.exec(
select(OperationalUnit).where(
cast(SAColumn, OperationalUnit.code).in_(operational_unit_codes)
)
)
db_session.add_all(
UserGroup(user_id=user.id, group_id=group.id) for group in user_groups
)
db_session.add_all(
GroupOperationalUnit(group_id=group.id, operational_unit_id=operational_unit.id)
for group, operational_unit in zip(groups, operational_units)
)

# Get operational unit codes
user_operational_unit_codes = get_user_operational_units(user, db_session)
assert len(user_operational_unit_codes) == user_n_groups
assert set(user_operational_unit_codes) == {
operational_unit.code
for group in user.groups
for operational_unit in group.operational_units
}

0 comments on commit 2f32ff9

Please sign in to comment.