Skip to content

Commit

Permalink
Integrate auditability in the API
Browse files Browse the repository at this point in the history
  • Loading branch information
jmaupetit committed Dec 5, 2024
1 parent f7ec67f commit 36e70d1
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/api/qualicharge/api/v1/routers/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ async def update(

transaction = session.begin_nested()
try:
update = update_statique(session, id_pdc_itinerance, statique)
update = update_statique(session, id_pdc_itinerance, statique, author=user)
except QCIntegrityError as err:
transaction.rollback()
raise HTTPException(
Expand Down Expand Up @@ -243,7 +243,7 @@ async def create(

transaction = session.begin_nested()
try:
db_statique = save_statique(session, statique)
db_statique = save_statique(session, statique, author=user)
except ObjectDoesNotExist as err:
transaction.rollback()
raise HTTPException(
Expand Down Expand Up @@ -277,7 +277,7 @@ async def bulk(
)

transaction = session.begin_nested()
importer = StatiqueImporter(df, transaction.session.connection())
importer = StatiqueImporter(df, transaction.session.connection(), author=user)
try:
importer.save()
except (
Expand Down
18 changes: 12 additions & 6 deletions src/api/qualicharge/schemas/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sqlalchemy.schema import MetaData
from typing_extensions import Optional

from ..auth.schemas import User
from ..exceptions import ObjectDoesNotExist, ProgrammingError
from ..models.static import Statique
from .audit import BaseAuditableSQLModel
Expand All @@ -39,7 +40,9 @@
class StatiqueImporter:
"""Statique importer from a Pandas Dataframe."""

def __init__(self, df: pd.DataFrame, connection: Connection):
def __init__(
self, df: pd.DataFrame, connection: Connection, author: Optional[User] = None
):
"""Add table cache keys."""
logger.info("Loading input dataframe containing %d rows", len(df))

Expand All @@ -57,20 +60,20 @@ def __init__(self, df: pd.DataFrame, connection: Connection):
self._operational_units: Optional[pd.DataFrame] = None

self.connection: Connection = connection
self.author: Optional[User] = author

def __len__(self):
"""Object length corresponds to the static dataframe length."""
return len(self._statique)

@staticmethod
def _add_auditable_model_fields(df: pd.DataFrame):
def _add_auditable_model_fields(self, df: pd.DataFrame):
"""Add required fields for a BaseAuditableSQLModel."""
df["id"] = df.apply(lambda x: uuid.uuid4(), axis=1)
now = pd.Timestamp.now(tz="utc")
df["created_at"] = now
df["updated_at"] = now
df["created_by_id"] = None
df["updated_by_id"] = None
df["created_by_id"] = self.author.id if self.author else None
df["updated_by_id"] = self.author.id if self.author else None
return df

@staticmethod
Expand Down Expand Up @@ -247,7 +250,10 @@ def _save_schema(
f: stmt.excluded.get(f)
for f in self._get_fields_for_schema(schema, with_fk=True)
}
updates_on_conflict.update({"updated_at": stmt.excluded.updated_at})
updates_on_conflict.update({
"updated_at": stmt.excluded.updated_at,
"updated_by_id": stmt.excluded.updated_by_id,
})
stmt = stmt.on_conflict_do_update(
constraint=constraint,
index_elements=index_elements,
Expand Down
52 changes: 41 additions & 11 deletions src/api/qualicharge/schemas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy import func
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.schema import Column as SAColumn
from qualicharge.schemas.audit import BaseAuditableSQLModel
from sqlmodel import Session, SQLModel, select

from qualicharge.auth.schemas import User
Expand Down Expand Up @@ -90,7 +91,7 @@ def get_or_create(

# Update database entry
for key, value in entry.model_dump(
exclude=DB_TO_STATIC_EXCLUDED_FIELDS
exclude=list(set(DB_TO_STATIC_EXCLUDED_FIELDS) - {"updated_by_id"})
).items():
setattr(db_entry, key, value)
session.add(db_entry)
Expand All @@ -110,6 +111,7 @@ def save_schema_from_statique(
statique: Statique,
fields: Optional[Set] = None,
update: bool = False,
author: Optional[User] = None,
) -> Tuple[EntryStatus, SQLModel]:
"""Save schema to database from Statique instance.
Expand All @@ -120,6 +122,7 @@ def save_schema_from_statique(
fields: entry fields used in database query to select target entry.
Defaults to None (use all fields).
update: should we update existing instance if required?
author: the user that creates/updates the schema entry
Returns:
A (EntryStatus, entry) tuple. The status refers on the entry creation/update.
Expand All @@ -129,6 +132,12 @@ def save_schema_from_statique(
"""
# Is this a new entry?
entry = schema_klass(**statique.get_fields_for_schema(schema_klass))

# Add author for auditability
if issubclass(schema_klass, BaseAuditableSQLModel):
entry.created_by_id = author.id if author else None
entry.updated_by_id = author.id if author else None

return get_or_create(
session,
entry,
Expand All @@ -150,23 +159,38 @@ def pdc_to_statique(pdc: PointDeCharge) -> Statique:


def save_statique(
session: Session, statique: Statique, update: bool = False
session: Session,
statique: Statique,
update: bool = False,
author: Optional[User] = None,
) -> Statique:
"""Save Statique instance to database."""
# Core schemas
_, pdc = save_schema_from_statique(
session, PointDeCharge, statique, fields={"id_pdc_itinerance"}, update=update
session,
PointDeCharge,
statique,
fields={"id_pdc_itinerance"},
update=update,
author=author,
)
_, station = save_schema_from_statique(
session, Station, statique, fields={"id_station_itinerance"}, update=update
session,
Station,
statique,
fields={"id_station_itinerance"},
update=update,
author=author,
)
_, amenageur = save_schema_from_statique(
session, Amenageur, statique, update=update
session, Amenageur, statique, update=update, author=author
)
_, operateur = save_schema_from_statique(
session, Operateur, statique, update=update
session, Operateur, statique, update=update, author=author
)
_, enseigne = save_schema_from_statique(
session, Enseigne, statique, update=update, author=author
)
_, enseigne = save_schema_from_statique(session, Enseigne, statique, update=update)
_, localisation = save_schema_from_statique(
session,
Localisation,
Expand All @@ -175,6 +199,7 @@ def save_statique(
"adresse_station",
},
update=update,
author=author,
)

# Relationships
Expand All @@ -195,7 +220,10 @@ def save_statique(


def update_statique(
session: Session, id_pdc_itinerance: str, to_update: Statique
session: Session,
id_pdc_itinerance: str,
to_update: Statique,
author: Optional[User] = None,
) -> Statique:
"""Update given statique from its id_pdc_itinerance."""
# Check that submitted id_pdc_itinerance corresponds to the update
Expand All @@ -215,17 +243,19 @@ def update_statique(
):
raise ObjectDoesNotExist("Statique with id_pdc_itinerance does not exist")

return save_statique(session, to_update, update=True)
return save_statique(session, to_update, update=True, author=author)


def save_statiques(db_session: Session, statiques: List[Statique]):
def save_statiques(
db_session: Session, statiques: List[Statique], author: Optional[User] = None
):
"""Save input statiques to database."""
df = pd.read_json(
StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"),
lines=True,
dtype_backend="pyarrow",
)
importer = StatiqueImporter(df, db_session.connection())
importer = StatiqueImporter(df, db_session.connection(), author=author)
importer.save()


Expand Down

0 comments on commit 36e70d1

Please sign in to comment.