Skip to content

Commit

Permalink
🐛(api) allow to update all statique-related model fields
Browse files Browse the repository at this point in the history
And not only create new entries when unique fields change.
  • Loading branch information
jmaupetit committed Jun 10, 2024
1 parent bea2879 commit b6f56eb
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to
### Fixed

- Improve database transactions in statique endpoints
- Allow to update all statique-related model fields

## [0.8.0] - 2024-05-31

Expand Down
29 changes: 23 additions & 6 deletions src/api/qualicharge/models/static.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""QualiCharge static models."""

import json
import re
from enum import StrEnum
from typing import Optional

Expand Down Expand Up @@ -62,15 +63,31 @@ class FrenchPhoneNumber(PhoneNumber):
default_region_code = "FR"


def to_coordinates_tuple(value):
"""Convert input string to a Coordinate tuple.
Two string formats are supported:
1. "[longitude: float, latitude: float]"
2. "POINT(longitude latitude)"
In both cases, the input string is converted to a reversed tuple
(latitude: float, longitude: float) that will be used as Coordinate input.
"""
if not isinstance(value, str):
return value
if m := re.match(
r"POINT\((?P<longitude>-?\d+\.\d+) (?P<latitude>-?\d+\.\d+)\)", value
):
return (m["latitude"], m["longitude"])
return tuple(reversed(json.loads(value)))


# A pivot type to handle DataGouv coordinates de/serialization.
DataGouvCoordinate = Annotated[
Coordinate,
# Input string format is: "[longitude: float, latitude: float]". It is converted to
# a reversed tuple (latitude: float, longitude: float) that will be used as
# Coordinate input.
BeforeValidator(
lambda x: tuple(reversed(json.loads(x))) if isinstance(x, str) else x
),
# Convert input string to a (latitude, longitude) Coordinate tuple input
BeforeValidator(to_coordinates_tuple),
# When serializing a coordinate we want a string array: "[long,lat]"
PlainSerializer(lambda x: f"[{x.longitude}, {x.latitude}]", return_type=str),
# Document expected longitude/latitude order in the description
Expand Down
66 changes: 50 additions & 16 deletions src/api/qualicharge/schemas/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""QualiCharge schemas utilities."""

import logging
from enum import IntEnum
from typing import Generator, List, NamedTuple, Optional, Set, Tuple, Type, cast

from sqlalchemy import func
Expand Down Expand Up @@ -32,9 +33,21 @@
DB_TO_STATIC_EXCLUDED_FIELDS = {"id", "created_at", "updated_at"}


class EntryStatus(IntEnum):
"""Describe entry status."""

EXISTS = 0
CREATED = 1
UPDATED = 2


def get_or_create(
session: Session, entry: SQLModel, fields: Optional[Set] = None, add: bool = True
) -> Tuple[bool, SQLModel]:
session: Session,
entry: SQLModel,
fields: Optional[Set] = None,
add: bool = True,
update: bool = False,
) -> Tuple[EntryStatus, SQLModel]:
"""Get or create schema instance.
Args:
Expand All @@ -43,9 +56,10 @@ def get_or_create(
fields: entry fields used in database query to select target entry.
Defaults to None (use all fields).
add: should we add the schema instance to the session?
update: should we update existing instance if required?
Returns:
A (bool, entry) tuple. The boolean states on the entry creation.
A (EntryStatus, entry) tuple. The status refers on the entry creation/update.
Raises:
DatabaseQueryException: Found multiple entries given input fields.
Expand All @@ -63,21 +77,32 @@ def get_or_create(

if db_entry is not None:
logger.debug(f"Found database entry with id: {db_entry.id}") # type: ignore[attr-defined]
return False, db_entry
if not update:
return EntryStatus.EXISTS, db_entry

# Update database entry
for key, value in entry.model_dump(
exclude=DB_TO_STATIC_EXCLUDED_FIELDS
).items():
setattr(db_entry, key, value)
session.add(db_entry)

return EntryStatus.UPDATED, db_entry

# Add new entry
if add:
session.add(entry)

return True, entry
return EntryStatus.CREATED, entry


def save_schema_from_statique(
session: Session,
schema_klass: Type[SQLModel],
statique: Statique,
fields: Optional[Set] = None,
) -> Tuple[bool, SQLModel]:
update: bool = False,
) -> Tuple[EntryStatus, SQLModel]:
"""Save schema to database from Statique instance.
Args:
Expand All @@ -86,9 +111,10 @@ def save_schema_from_statique(
statique: input static model definition
fields: entry fields used in database query to select target entry.
Defaults to None (use all fields).
update: should we update existing instance if required?
Returns:
A (bool, entry) tuple. The boolean states on the entry creation.
A (EntryStatus, entry) tuple. The status refers on the entry creation/update.
Raises:
DatabaseQueryException: Found multiple entries given input fields.
Expand All @@ -99,6 +125,7 @@ def save_schema_from_statique(
session,
entry,
fields=fields,
update=update,
)


Expand All @@ -114,25 +141,32 @@ def pdc_to_statique(pdc: PointDeCharge) -> Statique:
)


def save_statique(session: Session, statique: Statique) -> Statique:
def save_statique(
session: Session, statique: Statique, update: bool = False
) -> Statique:
"""Save Statique instance to database."""
# Core schemas
_, pdc = save_schema_from_statique(
session, PointDeCharge, statique, fields={"id_pdc_itinerance"}
session, PointDeCharge, statique, fields={"id_pdc_itinerance"}, update=update
)
_, station = save_schema_from_statique(
session, Station, statique, fields={"id_station_itinerance"}
session, Station, statique, fields={"id_station_itinerance"}, update=update
)
_, amenageur = save_schema_from_statique(
session, Amenageur, statique, update=update
)
_, operateur = save_schema_from_statique(
session, Operateur, statique, update=update
)
_, amenageur = save_schema_from_statique(session, Amenageur, statique)
_, operateur = save_schema_from_statique(session, Operateur, statique)
_, enseigne = save_schema_from_statique(session, Enseigne, statique)
_, enseigne = save_schema_from_statique(session, Enseigne, statique, update=update)
_, localisation = save_schema_from_statique(
session,
Localisation,
statique,
fields={
"adresse_station",
},
update=update,
)

# Relationships
Expand All @@ -153,11 +187,11 @@ def save_statique(session: Session, statique: Statique) -> Statique:


def update_statique(
session: Session, id_pdc_itinerance: str, update: Statique
session: Session, id_pdc_itinerance: str, to_update: Statique
) -> Statique:
"""Update given statique from its id_pdc_itinerance."""
# Check that submitted id_pdc_itinerance corresponds to the update
if id_pdc_itinerance != update.id_pdc_itinerance:
if id_pdc_itinerance != to_update.id_pdc_itinerance:
raise IntegrityError(
"Cannot update statique with a different id_pdc_itinerance"
)
Expand All @@ -173,7 +207,7 @@ def update_statique(
):
raise ObjectDoesNotExist("Statique with id_pdc_itinerance does not exist")

return save_statique(session, update)
return save_statique(session, to_update, update=True)


def save_statiques(
Expand Down
29 changes: 26 additions & 3 deletions src/api/tests/api/v1/routers/test_statique.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
from fastapi import status
from pydantic_extra_types.coordinate import Coordinate
from sqlalchemy import Column as SAColumn
from sqlalchemy import func
from sqlmodel import select
Expand All @@ -14,7 +15,11 @@
from qualicharge.auth.schemas import GroupOperationalUnit, ScopesEnum, User, UserGroup
from qualicharge.conf import settings
from qualicharge.factories.static import StatiqueFactory
from qualicharge.schemas.core import OperationalUnit, PointDeCharge, Station
from qualicharge.schemas.core import (
OperationalUnit,
PointDeCharge,
Station,
)
from qualicharge.schemas.utils import pdc_to_statique, save_statique, save_statiques


Expand Down Expand Up @@ -473,10 +478,28 @@ def test_update_for_superuser(client_auth, db_session):
"""Test the /statique/{id_pdc_itinerance} update endpoint (superuser case)."""
id_pdc_itinerance = "FR911E1111ER1"
db_statique = save_statique(
db_session, StatiqueFactory.build(id_pdc_itinerance=id_pdc_itinerance)
db_session,
StatiqueFactory.build(
id_pdc_itinerance=id_pdc_itinerance,
nom_amenageur="ACME Inc.",
nom_operateur="ACME Inc.",
nom_enseigne="ACME Inc.",
coordonneesXY=Coordinate(-1.0, 1.0),
station_deux_roues=False,
cable_t2_attache=False,
),
)
new_statique = db_statique.model_copy(
update={"contact_oprateur": "[email protected]"}, deep=True
update={
"contact_operateur": "[email protected]",
"nom_amenageur": "Magma Corp.",
"nom_operateur": "Magma Corp.",
"nom_enseigne": "Magma Corp.",
"coordonneesXY": Coordinate(1.0, 2.0),
"station_deux_roues": True,
"cable_t2_attache": True,
},
deep=True,
)

response = client_auth.put(
Expand Down
9 changes: 8 additions & 1 deletion src/api/tests/models/test_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def test_statique_model_coordonneesXY():
"""Test the Statique model coordonneesXY field."""
longitude = 12.3
longitude = -12.3
latitude = 16.2

# Expected raw input
Expand All @@ -33,6 +33,13 @@ def test_statique_model_coordonneesXY():
assert record.coordonneesXY.longitude == longitude
assert record.coordonneesXY.latitude == latitude

# Geometry input
record = StatiqueFactory.build(
coordonneesXY=f"POINT({longitude} {latitude})",
)
assert record.coordonneesXY.longitude == longitude
assert record.coordonneesXY.latitude == latitude


@pytest.mark.parametrize(
"phone_number",
Expand Down
Loading

0 comments on commit b6f56eb

Please sign in to comment.