Skip to content

Commit

Permalink
⚡️(api) improve bulk endpoints permissions checking
Browse files Browse the repository at this point in the history
Refactored a bit dynamic data linking to points of charge and
permissions using sets.
  • Loading branch information
jmaupetit committed Dec 12, 2024
1 parent 76d22d3 commit 1322d06
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 55 deletions.
1 change: 1 addition & 0 deletions src/api/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to
### Changed

- Prefetch user-related groups and operational units in `get_user` dependency
- Improve bulk endpoints permissions checking

## [0.16.0] - 2024-12-12

Expand Down
93 changes: 43 additions & 50 deletions src/api/qualicharge/api/v1/routers/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from qualicharge.schemas.core import PointDeCharge, Station, Status
from qualicharge.schemas.core import Session as QCSession
from qualicharge.schemas.utils import is_pdc_allowed_for_user
from qualicharge.schemas.utils import are_pdcs_allowed_for_user, is_pdc_allowed_for_user

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -313,18 +313,18 @@ async def create_status(
if not is_pdc_allowed_for_user(status.id_pdc_itinerance, user):
raise PermissionDenied("You cannot create statuses for this point of charge")

pdc = session.exec(
select(PointDeCharge).where(
pdc_id = session.exec(
select(PointDeCharge.id).where(
PointDeCharge.id_pdc_itinerance == status.id_pdc_itinerance
)
).one_or_none()
if pdc is None:
if pdc_id is None:
raise HTTPException(
status_code=fa_status.HTTP_404_NOT_FOUND,
detail="Attached point of charge does not exist",
)
db_status = Status(**status.model_dump(exclude={"id_pdc_itinerance"}))
db_status.point_de_charge_id = pdc.id
db_status.point_de_charge_id = pdc_id
session.add(db_status)
session.commit()

Expand All @@ -338,23 +338,23 @@ async def create_status_bulk(
session: Session = Depends(get_session),
) -> DynamiqueItemsCreatedResponse:
"""Create a statuses batch."""
for status in statuses:
if not is_pdc_allowed_for_user(status.id_pdc_itinerance, user):
raise PermissionDenied(
"You cannot submit data for an organization you are not assigned to"
)

# Check if all points of charge exist
# ids_pdc_itinerance = list({status.id_pdc_itinerance for status in statuses})
ids_pdc_itinerance = [status.id_pdc_itinerance for status in statuses]
ids_pdc_itinerance_set = set(ids_pdc_itinerance)
db_pdcs = session.exec(
select(PointDeCharge).filter(
cast(SAColumn, PointDeCharge.id_pdc_itinerance).in_(ids_pdc_itinerance_set)
ids_pdc_itinerance = {s.id_pdc_itinerance for s in statuses}
if not are_pdcs_allowed_for_user(ids_pdc_itinerance, user):
raise PermissionDenied(
"You cannot submit data for an organization you are not assigned to"
)
).all()

if len(db_pdcs) != len(ids_pdc_itinerance_set):
# Create a dict with keys as id_pdc_itinerance and values as PDC id
# for existing PDCs
db_pdcs = dict(
session.exec(
select(PointDeCharge.id_pdc_itinerance, PointDeCharge.id).filter(
cast(SAColumn, PointDeCharge.id_pdc_itinerance).in_(ids_pdc_itinerance)
)
).all()
)

if len(db_pdcs) != len(ids_pdc_itinerance):
raise HTTPException(
status_code=fa_status.HTTP_404_NOT_FOUND,
detail=(
Expand All @@ -363,15 +363,11 @@ async def create_status_bulk(
),
)

# Prepare statuses PDC index
db_pdc_ids = [pdc.id_pdc_itinerance for pdc in db_pdcs]
pdc_indexes = [db_pdc_ids.index(id_) for id_ in ids_pdc_itinerance]

# Create all statuses
db_statuses = []
for status, pdc_index in zip(statuses, pdc_indexes, strict=True):
for status in statuses:
db_status = Status(**status.model_dump(exclude={"id_pdc_itinerance"}))
db_status.point_de_charge_id = db_pdcs[pdc_index].id
db_status.point_de_charge_id = db_pdcs[status.id_pdc_itinerance]
db_statuses.append(db_status)
session.add_all(db_statuses)
session.commit()
Expand All @@ -396,18 +392,18 @@ async def create_session(
#
# - `db_session` / `Session` refers to the database session, while,
# - `session` / `QCSession` / `SessionCreate` refers to qualicharge charging session
pdc = db_session.exec(
select(PointDeCharge).where(
pdc_id = db_session.exec(
select(PointDeCharge.id).where(
PointDeCharge.id_pdc_itinerance == session.id_pdc_itinerance
)
).one_or_none()
if pdc is None:
if pdc_id is None:
raise HTTPException(
status_code=fa_status.HTTP_404_NOT_FOUND,
detail="Attached point of charge does not exist",
)
db_qc_session = QCSession(**session.model_dump(exclude={"id_pdc_itinerance"}))
db_qc_session.point_de_charge_id = pdc.id
db_qc_session.point_de_charge_id = pdc_id
db_session.add(db_qc_session)
db_session.commit()

Expand All @@ -421,22 +417,23 @@ async def create_session_bulk(
db_session: Session = Depends(get_session),
) -> DynamiqueItemsCreatedResponse:
"""Create a sessions batch."""
for session in sessions:
if not is_pdc_allowed_for_user(session.id_pdc_itinerance, user):
raise PermissionDenied(
"You cannot submit data for an organization you are not assigned to"
)

# Check if all points of charge exist
ids_pdc_itinerance = [session.id_pdc_itinerance for session in sessions]
ids_pdc_itinerance_set = set(ids_pdc_itinerance)
db_pdcs = db_session.exec(
select(PointDeCharge).filter(
cast(SAColumn, PointDeCharge.id_pdc_itinerance).in_(ids_pdc_itinerance_set)
ids_pdc_itinerance = {s.id_pdc_itinerance for s in sessions}
if not are_pdcs_allowed_for_user(ids_pdc_itinerance, user):
raise PermissionDenied(
"You cannot submit data for an organization you are not assigned to"
)
).all()

if len(db_pdcs) != len(ids_pdc_itinerance_set):
# Create a dict with keys as id_pdc_itinerance and values as PDC id
# for existing PDCs
db_pdcs = dict(
db_session.exec(
select(PointDeCharge.id_pdc_itinerance, PointDeCharge.id).filter(
cast(SAColumn, PointDeCharge.id_pdc_itinerance).in_(ids_pdc_itinerance)
)
).all()
)

if len(db_pdcs) != len(ids_pdc_itinerance):
raise HTTPException(
status_code=fa_status.HTTP_404_NOT_FOUND,
detail=(
Expand All @@ -445,15 +442,11 @@ async def create_session_bulk(
),
)

# Prepare statuses PDC index
db_pdc_ids = [pdc.id_pdc_itinerance for pdc in db_pdcs]
pdc_indexes = [db_pdc_ids.index(id_) for id_ in ids_pdc_itinerance]

# Create all statuses
db_qc_sessions = []
for session, pdc_index in zip(sessions, pdc_indexes, strict=True):
for session in sessions:
db_qc_session = QCSession(**session.model_dump(exclude={"id_pdc_itinerance"}))
db_qc_session.point_de_charge_id = db_pdcs[pdc_index].id
db_qc_session.point_de_charge_id = db_pdcs[session.id_pdc_itinerance]
db_qc_sessions.append(db_qc_session)
db_session.add_all(db_qc_sessions)
db_session.commit()
Expand Down
10 changes: 5 additions & 5 deletions src/api/qualicharge/api/v1/routers/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from qualicharge.schemas.core import OperationalUnit, PointDeCharge, Station
from qualicharge.schemas.sql import StatiqueImporter
from qualicharge.schemas.utils import (
are_pdcs_allowed_for_user,
build_statique,
is_pdc_allowed_for_user,
list_statique,
Expand Down Expand Up @@ -277,11 +278,10 @@ async def bulk(
session: Session = Depends(get_session),
) -> StatiqueItemsCreatedResponse:
"""Create a set of statique items."""
for statique in statiques:
if not is_pdc_allowed_for_user(statique.id_pdc_itinerance, user):
raise PermissionDenied(
"You cannot submit data for an organization you are not assigned to"
)
if not are_pdcs_allowed_for_user([s.id_pdc_itinerance for s in statiques], user):
raise PermissionDenied(
"You cannot submit data for an organization you are not assigned to"
)

# Convert statiques to a Pandas DataFrame
df = pd.read_json(
Expand Down
9 changes: 9 additions & 0 deletions src/api/qualicharge/schemas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,12 @@ def is_pdc_allowed_for_user(id_pdc_itinerance: str, user: User) -> bool:
if id_pdc_itinerance[:5] in [ou.code for ou in user.operational_units]:
return True
return False


def are_pdcs_allowed_for_user(ids_pdc_itinerance: set | list, user) -> bool:
"""Check of a user can create/read/update a list of PDCs given their identifiers."""
if user.is_superuser:
return True
operational_unit_codes: set = {id_[:5] for id_ in ids_pdc_itinerance}
user_operational_unit_codes: set = {ou.code for ou in user.operational_units}
return operational_unit_codes.issubset(user_operational_unit_codes)
37 changes: 37 additions & 0 deletions src/api/tests/schemas/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from qualicharge.schemas.utils import (
EntryStatus,
are_pdcs_allowed_for_user,
build_statique,
get_or_create,
is_pdc_allowed_for_user,
Expand Down Expand Up @@ -501,3 +502,39 @@ def test_is_pdc_allowed_for_user(db_session):
user = UserFactory.create_sync(is_superuser=False)
db_session.add(UserGroup(user_id=user.id, group_id=group.id))
assert is_pdc_allowed_for_user(id_pdc_itinerance, user) is False


def test_are_pdcs_allowed_for_user(db_session):
"""Test the are_pdcs_allowed_for_user utility."""
UserFactory.__session__ = db_session
GroupFactory.__session__ = db_session

# Superuser
user = UserFactory.create_sync(is_superuser=True)
assert are_pdcs_allowed_for_user(["FRFASE001", "FRS63E0001"], user) is True
assert are_pdcs_allowed_for_user({"FRFASE001", "FRS63E0001"}, user) is True

# Normal user with no assigned operational units
user = UserFactory.create_sync(is_superuser=False)
assert user.operational_units == []
assert are_pdcs_allowed_for_user(["FRFASE001", "FRS63E0001"], user) is False

# Create groups linked to Operational Units
groups = GroupFactory.create_batch_sync(3)
operational_units = db_session.exec(select(OperationalUnit).limit(3)).all()
for group, operational_unit in zip(groups, operational_units, strict=True):
db_session.add(
GroupOperationalUnit(
group_id=group.id, operational_unit_id=operational_unit.id
)
)

# And a user belonging to groups linked to those operational units
user = UserFactory.create_sync(is_superuser=False, groups=groups)
assert {ou.code for ou in user.operational_units} == {
ou.code for ou in operational_units
}

ids_pdc_itinerance = [f"{ou.code}P001" for ou in operational_units]
assert are_pdcs_allowed_for_user(ids_pdc_itinerance, user) is True
assert are_pdcs_allowed_for_user(ids_pdc_itinerance + ["FRFASP0001"], user) is False

0 comments on commit 1322d06

Please sign in to comment.