diff --git a/src/api/qualicharge/api/v1/routers/static.py b/src/api/qualicharge/api/v1/routers/static.py index b32630e3..921884eb 100644 --- a/src/api/qualicharge/api/v1/routers/static.py +++ b/src/api/qualicharge/api/v1/routers/static.py @@ -210,7 +210,7 @@ async def update( transaction = session.begin_nested() try: - update = update_statique(session, id_pdc_itinerance, statique) + update = update_statique(session, id_pdc_itinerance, statique, author=user) except QCIntegrityError as err: transaction.rollback() raise HTTPException( @@ -243,7 +243,7 @@ async def create( transaction = session.begin_nested() try: - db_statique = save_statique(session, statique) + db_statique = save_statique(session, statique, author=user) except ObjectDoesNotExist as err: transaction.rollback() raise HTTPException( @@ -277,7 +277,7 @@ async def bulk( ) transaction = session.begin_nested() - importer = StatiqueImporter(df, transaction.session.connection()) + importer = StatiqueImporter(df, transaction.session.connection(), author=user) try: importer.save() except ( diff --git a/src/api/qualicharge/schemas/sql.py b/src/api/qualicharge/schemas/sql.py index 7458a037..5b0d4ffe 100644 --- a/src/api/qualicharge/schemas/sql.py +++ b/src/api/qualicharge/schemas/sql.py @@ -17,6 +17,7 @@ from sqlalchemy.schema import MetaData from typing_extensions import Optional +from ..auth.schemas import User from ..exceptions import ObjectDoesNotExist, ProgrammingError from ..models.static import Statique from .audit import BaseAuditableSQLModel @@ -39,7 +40,9 @@ class StatiqueImporter: """Statique importer from a Pandas Dataframe.""" - def __init__(self, df: pd.DataFrame, connection: Connection): + def __init__( + self, df: pd.DataFrame, connection: Connection, author: Optional[User] = None + ): """Add table cache keys.""" logger.info("Loading input dataframe containing %d rows", len(df)) @@ -57,20 +60,20 @@ def __init__(self, df: pd.DataFrame, connection: Connection): self._operational_units: Optional[pd.DataFrame] = None self.connection: Connection = connection + self.author: Optional[User] = author def __len__(self): """Object length corresponds to the static dataframe length.""" return len(self._statique) - @staticmethod - def _add_auditable_model_fields(df: pd.DataFrame): + def _add_auditable_model_fields(self, df: pd.DataFrame): """Add required fields for a BaseAuditableSQLModel.""" df["id"] = df.apply(lambda x: uuid.uuid4(), axis=1) now = pd.Timestamp.now(tz="utc") df["created_at"] = now df["updated_at"] = now - df["created_by_id"] = None - df["updated_by_id"] = None + df["created_by_id"] = self.author.id if self.author else None + df["updated_by_id"] = self.author.id if self.author else None return df @staticmethod @@ -247,7 +250,10 @@ def _save_schema( f: stmt.excluded.get(f) for f in self._get_fields_for_schema(schema, with_fk=True) } - updates_on_conflict.update({"updated_at": stmt.excluded.updated_at}) + updates_on_conflict.update({ + "updated_at": stmt.excluded.updated_at, + "updated_by_id": stmt.excluded.updated_by_id, + }) stmt = stmt.on_conflict_do_update( constraint=constraint, index_elements=index_elements, diff --git a/src/api/qualicharge/schemas/utils.py b/src/api/qualicharge/schemas/utils.py index 15e2a4ee..52259326 100644 --- a/src/api/qualicharge/schemas/utils.py +++ b/src/api/qualicharge/schemas/utils.py @@ -9,6 +9,7 @@ from sqlalchemy import func from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.schema import Column as SAColumn +from qualicharge.schemas.audit import BaseAuditableSQLModel from sqlmodel import Session, SQLModel, select from qualicharge.auth.schemas import User @@ -90,7 +91,7 @@ def get_or_create( # Update database entry for key, value in entry.model_dump( - exclude=DB_TO_STATIC_EXCLUDED_FIELDS + exclude=list(set(DB_TO_STATIC_EXCLUDED_FIELDS) - {"updated_by_id"}) ).items(): setattr(db_entry, key, value) session.add(db_entry) @@ -110,6 +111,7 @@ def save_schema_from_statique( statique: Statique, fields: Optional[Set] = None, update: bool = False, + author: Optional[User] = None, ) -> Tuple[EntryStatus, SQLModel]: """Save schema to database from Statique instance. @@ -120,6 +122,7 @@ def save_schema_from_statique( fields: entry fields used in database query to select target entry. Defaults to None (use all fields). update: should we update existing instance if required? + author: the user that creates/updates the schema entry Returns: A (EntryStatus, entry) tuple. The status refers on the entry creation/update. @@ -129,6 +132,12 @@ def save_schema_from_statique( """ # Is this a new entry? entry = schema_klass(**statique.get_fields_for_schema(schema_klass)) + + # Add author for auditability + if issubclass(schema_klass, BaseAuditableSQLModel): + entry.created_by_id = author.id if author else None + entry.updated_by_id = author.id if author else None + return get_or_create( session, entry, @@ -150,23 +159,38 @@ def pdc_to_statique(pdc: PointDeCharge) -> Statique: def save_statique( - session: Session, statique: Statique, update: bool = False + session: Session, + statique: Statique, + update: bool = False, + author: Optional[User] = None, ) -> Statique: """Save Statique instance to database.""" # Core schemas _, pdc = save_schema_from_statique( - session, PointDeCharge, statique, fields={"id_pdc_itinerance"}, update=update + session, + PointDeCharge, + statique, + fields={"id_pdc_itinerance"}, + update=update, + author=author, ) _, station = save_schema_from_statique( - session, Station, statique, fields={"id_station_itinerance"}, update=update + session, + Station, + statique, + fields={"id_station_itinerance"}, + update=update, + author=author, ) _, amenageur = save_schema_from_statique( - session, Amenageur, statique, update=update + session, Amenageur, statique, update=update, author=author ) _, operateur = save_schema_from_statique( - session, Operateur, statique, update=update + session, Operateur, statique, update=update, author=author + ) + _, enseigne = save_schema_from_statique( + session, Enseigne, statique, update=update, author=author ) - _, enseigne = save_schema_from_statique(session, Enseigne, statique, update=update) _, localisation = save_schema_from_statique( session, Localisation, @@ -175,6 +199,7 @@ def save_statique( "adresse_station", }, update=update, + author=author, ) # Relationships @@ -195,7 +220,10 @@ def save_statique( def update_statique( - session: Session, id_pdc_itinerance: str, to_update: Statique + session: Session, + id_pdc_itinerance: str, + to_update: Statique, + author: Optional[User] = None, ) -> Statique: """Update given statique from its id_pdc_itinerance.""" # Check that submitted id_pdc_itinerance corresponds to the update @@ -215,17 +243,19 @@ def update_statique( ): raise ObjectDoesNotExist("Statique with id_pdc_itinerance does not exist") - return save_statique(session, to_update, update=True) + return save_statique(session, to_update, update=True, author=author) -def save_statiques(db_session: Session, statiques: List[Statique]): +def save_statiques( + db_session: Session, statiques: List[Statique], author: Optional[User] = None +): """Save input statiques to database.""" df = pd.read_json( StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"), lines=True, dtype_backend="pyarrow", ) - importer = StatiqueImporter(df, db_session.connection()) + importer = StatiqueImporter(df, db_session.connection(), author=author) importer.save()