diff --git a/src/api/qualicharge/api/v1/routers/static.py b/src/api/qualicharge/api/v1/routers/static.py index b32630e3..921884eb 100644 --- a/src/api/qualicharge/api/v1/routers/static.py +++ b/src/api/qualicharge/api/v1/routers/static.py @@ -210,7 +210,7 @@ async def update( transaction = session.begin_nested() try: - update = update_statique(session, id_pdc_itinerance, statique) + update = update_statique(session, id_pdc_itinerance, statique, author=user) except QCIntegrityError as err: transaction.rollback() raise HTTPException( @@ -243,7 +243,7 @@ async def create( transaction = session.begin_nested() try: - db_statique = save_statique(session, statique) + db_statique = save_statique(session, statique, author=user) except ObjectDoesNotExist as err: transaction.rollback() raise HTTPException( @@ -277,7 +277,7 @@ async def bulk( ) transaction = session.begin_nested() - importer = StatiqueImporter(df, transaction.session.connection()) + importer = StatiqueImporter(df, transaction.session.connection(), author=user) try: importer.save() except ( diff --git a/src/api/qualicharge/auth/schemas.py b/src/api/qualicharge/auth/schemas.py index a72d7a36..98fb051e 100644 --- a/src/api/qualicharge/auth/schemas.py +++ b/src/api/qualicharge/auth/schemas.py @@ -9,7 +9,7 @@ from sqlmodel import Field, Relationship, SQLModel from qualicharge.conf import settings -from qualicharge.schemas import BaseTimestampedSQLModel +from qualicharge.schemas.audit import BaseAuditableSQLModel from qualicharge.schemas.core import OperationalUnit @@ -53,7 +53,7 @@ class ScopesEnum(StrEnum): # -- Core schemas -class User(BaseTimestampedSQLModel, table=True): +class User(BaseAuditableSQLModel, table=True): """QualiCharge User.""" id: UUID = Field(default_factory=uuid4, primary_key=True) @@ -103,7 +103,7 @@ def check_password(self, password: str) -> bool: return settings.PASSWORD_CONTEXT.verify(password, self.password) -class Group(BaseTimestampedSQLModel, table=True): +class 
Group(BaseAuditableSQLModel, table=True): """QualiCharge Group.""" id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/src/api/qualicharge/factories/__init__.py b/src/api/qualicharge/factories/__init__.py index d85f10ce..bb82920d 100644 --- a/src/api/qualicharge/factories/__init__.py +++ b/src/api/qualicharge/factories/__init__.py @@ -32,5 +32,7 @@ class TimestampedSQLModelFactory(Generic[T], SQLAlchemyFactory[T]): __is_base_factory__ = True id = Use(uuid4) + created_by_id = None + updated_by_id = None created_at = Use(lambda: datetime.now(timezone.utc) - timedelta(hours=1)) updated_at = Use(datetime.now, timezone.utc) diff --git a/src/api/qualicharge/fixtures/operational_units.py b/src/api/qualicharge/fixtures/operational_units.py index e72504e5..b81839a7 100644 --- a/src/api/qualicharge/fixtures/operational_units.py +++ b/src/api/qualicharge/fixtures/operational_units.py @@ -15,12 +15,13 @@ """ from collections import namedtuple +from typing import List from qualicharge.schemas.core import OperationalUnit, OperationalUnitTypeEnum # Operational units Item = namedtuple("Item", ["code", "name"]) -data = [ +data: List[Item] = [ Item( "FR073", "ACELEC CHARGE", diff --git a/src/api/qualicharge/migrations/env.py b/src/api/qualicharge/migrations/env.py index f00bba8d..39bd74d7 100644 --- a/src/api/qualicharge/migrations/env.py +++ b/src/api/qualicharge/migrations/env.py @@ -15,6 +15,7 @@ Group, GroupOperationalUnit, ) +from qualicharge.schemas.audit import Audit from qualicharge.schemas.core import ( # noqa: F401 Amenageur, Enseigne, diff --git a/src/api/qualicharge/migrations/versions/86b08ec6e6d1_add_auditability.py b/src/api/qualicharge/migrations/versions/86b08ec6e6d1_add_auditability.py new file mode 100644 index 00000000..a9bae7a1 --- /dev/null +++ b/src/api/qualicharge/migrations/versions/86b08ec6e6d1_add_auditability.py @@ -0,0 +1,268 @@ +"""add auditability + +Revision ID: 86b08ec6e6d1 +Revises: c09664a85912 +Create Date: 2024-10-16 
17:03:58.431420
+
+"""
+
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+import sqlmodel
+
+# revision identifiers, used by Alembic.
+revision: str = "86b08ec6e6d1"
+down_revision: Union[str, None] = "c09664a85912"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table(
+        "audit",
+        sa.Column("id", sa.Uuid(), nullable=False),
+        sa.Column("table", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("author_id", sa.Uuid(), nullable=False),
+        sa.Column("target_id", sa.Uuid(), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("changes", sa.JSON(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["author_id"],
+            ["user.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.add_column("amenageur", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("amenageur", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "amenageur_created_by_id_fkey", "amenageur", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "amenageur_updated_by_id_fkey", "amenageur", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("city", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("city", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "city_created_by_id_fkey", "city", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "city_updated_by_id_fkey", "city", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("department", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("department", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "department_created_by_id_fkey", "department", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "department_updated_by_id_fkey", "department", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("enseigne", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("enseigne", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "enseigne_created_by_id_fkey", "enseigne", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "enseigne_updated_by_id_fkey", "enseigne", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("epci", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("epci", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "epci_created_by_id_fkey", "epci", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "epci_updated_by_id_fkey", "epci", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("group", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("group", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "group_created_by_id_fkey", "group", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "group_updated_by_id_fkey", "group", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("localisation", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("localisation", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "localisation_created_by_id_fkey",
+        "localisation",
+        "user",
+        ["created_by_id"],
+        ["id"],
+    )
+    op.create_foreign_key(
+        "localisation_updated_by_id_fkey",
+        "localisation",
+        "user",
+        ["updated_by_id"],
+        ["id"],
+    )
+    op.add_column("operateur", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("operateur", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "operateur_created_by_id_fkey", "operateur", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "operateur_updated_by_id_fkey", "operateur", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column(
+        "operationalunit", sa.Column("created_by_id", sa.Uuid(), nullable=True)
+    )
+    op.add_column(
+        "operationalunit", sa.Column("updated_by_id", sa.Uuid(), nullable=True)
+    )
+    op.create_foreign_key(
+        "operationalunit_created_by_id_fkey",
+        "operationalunit",
+        "user",
+        ["created_by_id"],
+        ["id"],
+    )
+    op.create_foreign_key(
+        "operationalunit_updated_by_id_fkey",
+        "operationalunit",
+        "user",
+        ["updated_by_id"],
+        ["id"],
+    )
+    op.add_column("pointdecharge", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("pointdecharge", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "pointdecharge_created_by_id_fkey",
+        "pointdecharge",
+        "user",
+        ["created_by_id"],
+        ["id"],
+    )
+    op.create_foreign_key(
+        "pointdecharge_updated_by_id_fkey",
+        "pointdecharge",
+        "user",
+        ["updated_by_id"],
+        ["id"],
+    )
+    op.add_column("region", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("region", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "region_created_by_id_fkey", "region", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "region_updated_by_id_fkey", "region", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("session", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("session", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "session_created_by_id_fkey", "session", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "session_updated_by_id_fkey", "session", "user", ["updated_by_id"], ["id"]
+    )
+    op.add_column("station", sa.Column("created_by_id", sa.Uuid(), nullable=True))
+    op.add_column("station", sa.Column("updated_by_id", sa.Uuid(), nullable=True))
+    op.create_foreign_key(
+        "station_created_by_id_fkey", "station", "user", ["created_by_id"], ["id"]
+    )
+    op.create_foreign_key(
+        "station_updated_by_id_fkey", "station", "user", ["updated_by_id"], ["id"]
+    )
+
op.add_column("status", sa.Column("created_by_id", sa.Uuid(), nullable=True)) + op.add_column("status", sa.Column("updated_by_id", sa.Uuid(), nullable=True)) + op.create_foreign_key( + "status_created_by_id_fkey", "status", "user", ["created_by_id"], ["id"] + ) + op.create_foreign_key( + "status_updated_by_id_fkey", "status", "user", ["updated_by_id"], ["id"] + ) + op.add_column("user", sa.Column("created_by_id", sa.Uuid(), nullable=True)) + op.add_column("user", sa.Column("updated_by_id", sa.Uuid(), nullable=True)) + op.create_foreign_key( + "user_created_by_id_fkey", "user", "user", ["created_by_id"], ["id"] + ) + op.create_foreign_key( + "user_updated_by_id_fkey", "user", "user", ["updated_by_id"], ["id"] + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("user_created_by_id_fkey", "user", type_="foreignkey") + op.drop_constraint("user_updated_by_id_fkey", "user", type_="foreignkey") + op.drop_column("user", "updated_by_id") + op.drop_column("user", "created_by_id") + op.drop_constraint("status_created_by_id_fkey", "status", type_="foreignkey") + op.drop_constraint("status_updated_by_id_fkey", "status", type_="foreignkey") + op.drop_column("status", "updated_by_id") + op.drop_column("status", "created_by_id") + op.drop_constraint("station_created_by_id_fkey", "station", type_="foreignkey") + op.drop_constraint("station_updated_by_id_fkey", "station", type_="foreignkey") + op.drop_column("station", "updated_by_id") + op.drop_column("station", "created_by_id") + op.drop_constraint("session_created_by_id_fkey", "session", type_="foreignkey") + op.drop_constraint("session_updated_by_id_fkey", "session", type_="foreignkey") + op.drop_column("session", "updated_by_id") + op.drop_column("session", "created_by_id") + op.drop_constraint("region_created_by_id_fkey", "region", type_="foreignkey") + op.drop_constraint("region_updated_by_id_fkey", "region", 
type_="foreignkey") + op.drop_column("region", "updated_by_id") + op.drop_column("region", "created_by_id") + op.drop_constraint( + "pointdecharge_created_by_id_fkey", "pointdecharge", type_="foreignkey" + ) + op.drop_constraint( + "pointdecharge_updated_by_id_fkey", "pointdecharge", type_="foreignkey" + ) + op.drop_column("pointdecharge", "updated_by_id") + op.drop_column("pointdecharge", "created_by_id") + op.drop_constraint( + "operationalunit_created_by_id_fkey", "operationalunit", type_="foreignkey" + ) + op.drop_constraint( + "operationalunit_updated_by_id_fkey", "operationalunit", type_="foreignkey" + ) + op.drop_column("operationalunit", "updated_by_id") + op.drop_column("operationalunit", "created_by_id") + op.drop_constraint("operateur_created_by_id_fkey", "operateur", type_="foreignkey") + op.drop_constraint("operateur_updated_by_id_fkey", "operateur", type_="foreignkey") + op.drop_column("operateur", "updated_by_id") + op.drop_column("operateur", "created_by_id") + op.drop_constraint( + "localisation_created_by_id_fkey", "localisation", type_="foreignkey" + ) + op.drop_constraint( + "localisation_updated_by_id_fkey", "localisation", type_="foreignkey" + ) + op.drop_column("localisation", "updated_by_id") + op.drop_column("localisation", "created_by_id") + op.drop_constraint("group_created_by_id_fkey", "group", type_="foreignkey") + op.drop_constraint("group_updated_by_id_fkey", "group", type_="foreignkey") + op.drop_column("group", "updated_by_id") + op.drop_column("group", "created_by_id") + op.drop_constraint("epci_created_by_id_fkey", "epci", type_="foreignkey") + op.drop_constraint("epci_updated_by_id_fkey", "epci", type_="foreignkey") + op.drop_column("epci", "updated_by_id") + op.drop_column("epci", "created_by_id") + op.drop_constraint("enseigne_created_by_id_fkey", "enseigne", type_="foreignkey") + op.drop_constraint("enseigne_updated_by_id_fkey", "enseigne", type_="foreignkey") + op.drop_column("enseigne", "updated_by_id") + 
op.drop_column("enseigne", "created_by_id") + op.drop_constraint( + "department_created_by_id_fkey", "department", type_="foreignkey" + ) + op.drop_constraint( + "department_updated_by_id_fkey", "department", type_="foreignkey" + ) + op.drop_column("department", "updated_by_id") + op.drop_column("department", "created_by_id") + op.drop_constraint("city_created_by_id_fkey", "city", type_="foreignkey") + op.drop_constraint("city_updated_by_id_fkey", "city", type_="foreignkey") + op.drop_column("city", "updated_by_id") + op.drop_column("city", "created_by_id") + op.drop_constraint("amenageur_created_by_id_fkey", "amenageur", type_="foreignkey") + op.drop_constraint("amenageur_updated_by_id_fkey", "amenageur", type_="foreignkey") + op.drop_column("amenageur", "updated_by_id") + op.drop_column("amenageur", "created_by_id") + op.drop_table("audit") + # ### end Alembic commands ### diff --git a/src/api/qualicharge/migrations/versions/fda96abb970d_add_operational_unit_data.py b/src/api/qualicharge/migrations/versions/fda96abb970d_add_operational_unit_data.py index b8a4157a..a281d9c0 100644 --- a/src/api/qualicharge/migrations/versions/fda96abb970d_add_operational_unit_data.py +++ b/src/api/qualicharge/migrations/versions/fda96abb970d_add_operational_unit_data.py @@ -6,13 +6,16 @@ """ +from datetime import datetime, timezone from typing import Sequence, Union +from uuid import uuid4 + +from sqlalchemy import MetaData from alembic import op -from sqlmodel import Session, select -from qualicharge.fixtures.operational_units import operational_units -from qualicharge.schemas.core import Station +from qualicharge.fixtures.operational_units import data as operational_units +from qualicharge.schemas.core import OperationalUnitTypeEnum # revision identifiers, used by Alembic. 
@@ -32,32 +35,54 @@ def downgrade(): def data_upgrades(): """Add any optional data upgrade migrations here!""" - - # We are running in a transaction, hence we need to get the current active connection - session = Session(op.get_bind()) - # Reset table before inserting data data_downgrades() - session.add_all(operational_units) - session.commit() + + # Get OperationalUnit table + metadata = MetaData() + metadata.reflect(bind=op.get_bind()) + ou_table = metadata.tables["operationalunit"] + + # Bulk insert + now = datetime.now(timezone.utc) + op.bulk_insert( + ou_table, + [ + { + "id": uuid4().hex, + "created_at": now, + "updated_at": now, + "type": "CHARGING", + } + | ou._asdict() + for ou in operational_units + ], + ) # Create FK - for operational_unit in operational_units: - operational_unit.create_stations_fk(session) + op.execute( + """ + WITH station_ou AS ( + SELECT + Station.id as station_id, + OperationalUnit.id as operational_unit_id + FROM + Station + INNER JOIN OperationalUnit ON + SUBSTRING(Station.id_station_itinerance, 1, 5) = OperationalUnit.code + ) + UPDATE Station + SET operational_unit_id = station_ou.operational_unit_id + FROM station_ou + WHERE Station.id = station_ou.station_id + """ + ) def data_downgrades(): """Add any optional data downgrade migrations here!""" - - # We are running in a transaction, hence we need to get the current active connection - session = Session(op.get_bind()) - # Reset FK - stations = session.exec(select(Station)).all() - for station in stations: - station.operational_unit_id = None - session.add_all(stations) - session.commit() + op.execute("UPDATE Station SET operational_unit_id = NULL") # Delete records - op.execute("delete from operationalunit") + op.execute("DELETE FROM OperationalUnit") diff --git a/src/api/qualicharge/schemas/__init__.py b/src/api/qualicharge/schemas/__init__.py index 4290e078..c6cf2484 100644 --- a/src/api/qualicharge/schemas/__init__.py +++ b/src/api/qualicharge/schemas/__init__.py @@ -1,35 
+1 @@ """QualiCharge schemas.""" - -from datetime import datetime, timezone -from typing import Any - -from pydantic import PastDatetime -from sqlalchemy import CheckConstraint -from sqlalchemy.types import DateTime -from sqlmodel import Field, SQLModel - - -class BaseTimestampedSQLModel(SQLModel): - """A base class for SQL models with timestamp fields. - - This class provides two timestamp fields, `created_at` and `updated_at`, which are - automatically managed. The `created_at` field is set to the current UTC time when - a new record is created, and the `updated_at` field is updated to the current UTC - time whenever the record is modified. - """ - - __table_args__: Any = ( - CheckConstraint("created_at <= updated_at", name="pre-creation-update"), - ) - - created_at: PastDatetime = Field( - sa_type=DateTime(timezone=True), - default_factory=lambda: datetime.now(timezone.utc), - description="The timestamp indicating when the record was created.", - ) # type: ignore - updated_at: PastDatetime = Field( - sa_type=DateTime(timezone=True), - sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)}, - default_factory=lambda: datetime.now(timezone.utc), - description="The timestamp indicating when the record was last updated.", - ) # type: ignore diff --git a/src/api/qualicharge/schemas/audit.py b/src/api/qualicharge/schemas/audit.py new file mode 100644 index 00000000..02ae59fc --- /dev/null +++ b/src/api/qualicharge/schemas/audit.py @@ -0,0 +1,142 @@ +"""QualiCharge auditable schemas.""" + +import logging +from datetime import datetime, timezone +from enum import StrEnum +from typing import Any, Optional +from uuid import UUID, uuid4 + +from pydantic import PastDatetime +from sqlalchemy import CheckConstraint, and_, event, insert, inspect +from sqlalchemy.orm import backref, foreign, relationship, remote +from sqlalchemy.types import JSON, DateTime +from sqlmodel import Field, SQLModel + +logger = logging.getLogger(__name__) + + +class 
AuditableFieldBlackListEnum(StrEnum): + """Fields black listed for auditability.""" + + PASSWORD = "password" # noqa: S105 + + +class BaseAuditableSQLModel(SQLModel): + """A base class for auditable SQL models. + + This class provides two timestamp fields, `created_at` and `updated_at`, which are + automatically managed. The `created_at` field is set to the current UTC time when + a new record is created, and the `updated_at` field is updated to the current UTC + time whenever the record is modified. + + The two `created_by_id`, `updated_by_id` foreign keys points to the auth.User model. + Both keys are optional and need to be explicitly set by your code. + + To fully track changes, you need to connect the `track_model_changes` utility (see + below) to SQLAlchemy events. + """ + + __table_args__: Any = ( + CheckConstraint("created_at <= updated_at", name="pre-creation-update"), + ) + + created_at: PastDatetime = Field( + sa_type=DateTime(timezone=True), + default_factory=lambda: datetime.now(timezone.utc), + description="The timestamp indicating when the record was created.", + ) # type: ignore + updated_at: PastDatetime = Field( + sa_type=DateTime(timezone=True), + sa_column_kwargs={"onupdate": lambda: datetime.now(timezone.utc)}, + default_factory=lambda: datetime.now(timezone.utc), + description="The timestamp indicating when the record was last updated.", + ) # type: ignore + created_by_id: Optional[UUID] = Field(default=None, foreign_key="user.id") + updated_by_id: Optional[UUID] = Field(default=None, foreign_key="user.id") + + +class Audit(SQLModel, table=True): + """Model changes record for auditability.""" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + table: str + author_id: UUID = Field(foreign_key="user.id") + target_id: UUID + updated_at: PastDatetime = Field(sa_type=DateTime(timezone=True)) # type: ignore + changes: dict = Field(sa_type=JSON) + + +@event.listens_for(BaseAuditableSQLModel, "mapper_configured", propagate=True) +def 
add_audit_generic_fk(mapper, class_):
+    """Create a generic foreign key to the Audit table for auditable schemas."""
+    name = class_.__name__
+    discriminator = name.lower()
+    class_.audits = relationship(
+        Audit,
+        primaryjoin=and_(
+            class_.id == foreign(remote(Audit.target_id)),
+            Audit.table == discriminator,
+        ),
+        backref=backref(
+            "target_%s" % discriminator,
+            primaryjoin=remote(class_.id) == foreign(Audit.target_id),
+        ),
+    )
+
+    @event.listens_for(class_.audits, "append")
+    def append_audit(target, value, initiator):
+        value.table = discriminator
+
+
+def track_model_changes(mapper, connection, target):
+    """Track model changes for auditability.
+
+    Models to track are supposed to inherit from the `BaseAuditableSQLModel`. You
+    should listen to "after_update" events on your model to add full auditability
+    support, e.g.:
+
+    ```python
+    from sqlalchemy import event
+
+    from myapp.schemas import MyModel
+    from ..schemas.audit import track_model_changes
+
+
+    event.listen(MyModel, "after_update", track_model_changes)
+    ```
+
+    For each changed field, the previous value is stored along with the modification
+    date and the author. For fields with sensitive information (_e.g._ passwords or
+    tokens), a null value is stored.
+ """ + if target.updated_by_id is None: + logger.debug("Target updated_by_id is empty, aborting changes tracking.") + return + + state = inspect(target) + + # Get changes + changes = {} + for attr in state.attrs: + if attr.key in AuditableFieldBlackListEnum: + continue + history = attr.load_history() + if not history.has_changes(): + continue + changes[attr.key] = [ + str(history.deleted[0]) if len(history.deleted) else None, + str(history.added[0]) if len(history.added) else None, + ] + + logger.debug("Detected changes: %s", str(changes)) + + # Log changes + connection.execute( + insert(Audit).values( + table=target.__tablename__, + author_id=target.updated_by_id, + target_id=target.id, + updated_at=target.updated_at, + changes=changes, + ) + ) diff --git a/src/api/qualicharge/schemas/core.py b/src/api/qualicharge/schemas/core.py index 1a4d8f44..76a877a8 100644 --- a/src/api/qualicharge/schemas/core.py +++ b/src/api/qualicharge/schemas/core.py @@ -37,7 +37,7 @@ ImplantationStationEnum, RaccordementEnum, ) -from . 
import BaseTimestampedSQLModel +from .audit import BaseAuditableSQLModel, track_model_changes if TYPE_CHECKING: from qualicharge.auth.schemas import Group @@ -50,10 +50,10 @@ class OperationalUnitTypeEnum(IntEnum): MOBILITY = 2 -class Amenageur(BaseTimestampedSQLModel, table=True): +class Amenageur(BaseAuditableSQLModel, table=True): """Amenageur table.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( UniqueConstraint("nom_amenageur", "siren_amenageur", "contact_amenageur"), ) @@ -73,10 +73,10 @@ def __eq__(self, other) -> bool: return all(getattr(self, field) == getattr(other, field) for field in fields) -class Operateur(BaseTimestampedSQLModel, table=True): +class Operateur(BaseAuditableSQLModel, table=True): """Operateur table.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( UniqueConstraint("nom_operateur", "contact_operateur", "telephone_operateur"), ) @@ -96,7 +96,7 @@ def __eq__(self, other) -> bool: return all(getattr(self, field) == getattr(other, field) for field in fields) -class Enseigne(BaseTimestampedSQLModel, table=True): +class Enseigne(BaseAuditableSQLModel, table=True): """Enseigne table.""" model_config = SQLModelConfig(validate_assignment=True) @@ -113,7 +113,7 @@ def __eq__(self, other) -> bool: return all(getattr(self, field) == getattr(other, field) for field in fields) -class Localisation(BaseTimestampedSQLModel, table=True): +class Localisation(BaseAuditableSQLModel, table=True): """Localisation table.""" model_config = SQLModelConfig( @@ -178,7 +178,7 @@ def serialize_coordonneesXY( return self._wkb_to_coordinates(value) -class OperationalUnit(BaseTimestampedSQLModel, table=True): +class OperationalUnit(BaseAuditableSQLModel, table=True): """OperationalUnit table.""" id: UUID = Field(default_factory=uuid4, primary_key=True) @@ -217,7 +217,7 @@ def create_stations_fk(self, session: 
SMSession): session.commit() -class Station(BaseTimestampedSQLModel, table=True): +class Station(BaseAuditableSQLModel, table=True): """Station table.""" model_config = SQLModelConfig(validate_assignment=True) @@ -299,7 +299,7 @@ def link_station_to_operational_unit(mapper, connection, target): target.operational_unit_id = operational_unit.id -class PointDeCharge(BaseTimestampedSQLModel, table=True): +class PointDeCharge(BaseAuditableSQLModel, table=True): """PointDeCharge table.""" model_config = SQLModelConfig(validate_assignment=True) @@ -340,10 +340,10 @@ def __eq__(self, other) -> bool: statuses: List["Status"] = Relationship(back_populates="point_de_charge") -class Session(BaseTimestampedSQLModel, SessionBase, table=True): +class Session(BaseAuditableSQLModel, SessionBase, table=True): """IRVE recharge session.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( {"timescaledb_hypertable": {"time_column_name": "start"}}, ) @@ -364,10 +364,10 @@ class Session(BaseTimestampedSQLModel, SessionBase, table=True): point_de_charge: PointDeCharge = Relationship(back_populates="sessions") -class Status(BaseTimestampedSQLModel, StatusBase, table=True): +class Status(BaseAuditableSQLModel, StatusBase, table=True): """IRVE recharge session.""" - __table_args__ = BaseTimestampedSQLModel.__table_args__ + ( + __table_args__ = BaseAuditableSQLModel.__table_args__ + ( {"timescaledb_hypertable": {"time_column_name": "horodatage"}}, ) @@ -388,3 +388,15 @@ class Status(BaseTimestampedSQLModel, StatusBase, table=True): def id_pdc_itinerance(self) -> str: """Return the PointDeCharge.id_pdc_itinerance (used for serialization only).""" return self.point_de_charge.id_pdc_itinerance + + +# Declare auditable models +for auditable_model in ( + Amenageur, + Operateur, + Enseigne, + Localisation, + Station, + PointDeCharge, +): + event.listen(auditable_model, "after_update", track_model_changes) diff --git 
a/src/api/qualicharge/schemas/geo.py b/src/api/qualicharge/schemas/geo.py index 50fcc024..407a0f81 100644 --- a/src/api/qualicharge/schemas/geo.py +++ b/src/api/qualicharge/schemas/geo.py @@ -7,10 +7,10 @@ from sqlmodel import Field, Relationship from sqlmodel.main import SQLModelConfig -from . import BaseTimestampedSQLModel +from .audit import BaseAuditableSQLModel -class BaseAdministrativeBoundaries(BaseTimestampedSQLModel): +class BaseAdministrativeBoundaries(BaseAuditableSQLModel): """Base administrative boundaries model.""" model_config = SQLModelConfig( diff --git a/src/api/qualicharge/schemas/sql.py b/src/api/qualicharge/schemas/sql.py index 674669de..405b7389 100644 --- a/src/api/qualicharge/schemas/sql.py +++ b/src/api/qualicharge/schemas/sql.py @@ -17,9 +17,10 @@ from sqlalchemy.schema import MetaData from typing_extensions import Optional +from ..auth.schemas import User from ..exceptions import ObjectDoesNotExist, ProgrammingError from ..models.static import Statique -from . 
import BaseTimestampedSQLModel +from .audit import BaseAuditableSQLModel from .core import ( AccessibilitePMREnum, Amenageur, @@ -39,13 +40,15 @@ class StatiqueImporter: """Statique importer from a Pandas Dataframe.""" - def __init__(self, df: pd.DataFrame, connection: Connection): + def __init__( + self, df: pd.DataFrame, connection: Connection, author: Optional[User] = None + ): """Add table cache keys.""" logger.info("Loading input dataframe containing %d rows", len(df)) self._statique: pd.DataFrame = self._fix_enums(df) self._statique_with_fk: pd.DataFrame = self._statique.copy() - self._saved_schemas: list[type[BaseTimestampedSQLModel]] = [] + self._saved_schemas: list[type[BaseAuditableSQLModel]] = [] self._amenageur: Optional[pd.DataFrame] = None self._enseigne: Optional[pd.DataFrame] = None @@ -57,27 +60,29 @@ def __init__(self, df: pd.DataFrame, connection: Connection): self._operational_units: Optional[pd.DataFrame] = None self.connection: Connection = connection + self.author: Optional[User] = author def __len__(self): """Object length corresponds to the static dataframe length.""" return len(self._statique) - @staticmethod - def _add_timestamped_model_fields(df: pd.DataFrame): - """Add required fields for a BaseTimestampedSQLModel.""" + def _add_auditable_model_fields(self, df: pd.DataFrame): + """Add required fields for a BaseAuditableSQLModel.""" df["id"] = df.apply(lambda x: uuid.uuid4(), axis=1) now = pd.Timestamp.now(tz="utc") df["created_at"] = now df["updated_at"] = now + df["created_by_id"] = self.author.id if self.author else None + df["updated_by_id"] = self.author.id if self.author else None return df @staticmethod - def _schema_fk(schema: type[BaseTimestampedSQLModel]) -> str: + def _schema_fk(schema: type[BaseAuditableSQLModel]) -> str: """Get expected schema foreign key name.""" return f"{schema.__table__.name}_id" # type: ignore[attr-defined] @staticmethod - def _get_schema_fks(schema: type[BaseTimestampedSQLModel]) -> list[str]: + def 
_get_schema_fks(schema: type[BaseAuditableSQLModel]) -> list[str]: """Get foreign key field names from a schema.""" return [ fk.parent.name @@ -103,19 +108,21 @@ def _fix_enums(self, df: pd.DataFrame) -> pd.DataFrame: return df.replace(to_replace=src, value=target) def _get_fields_for_schema( - self, schema: type[BaseTimestampedSQLModel], with_fk: bool = False + self, schema: type[BaseAuditableSQLModel], with_fk: bool = False ) -> list[str]: """Get Statique fields from a core schema.""" fields = list( set(Statique.model_fields.keys()) & set(schema.model_fields.keys()) ) + # Auditable model fks should be ignored + ignored_fks = {"created_by_id", "updated_by_id"} if with_fk: - fields += self._get_schema_fks(schema) + fields += list(set(self._get_schema_fks(schema)) - ignored_fks) return fields def _get_dataframe_for_schema( self, - schema: type[BaseTimestampedSQLModel], + schema: type[BaseAuditableSQLModel], subset: Optional[str] = None, with_fk: bool = False, ): @@ -123,11 +130,11 @@ def _get_dataframe_for_schema( src = self._statique_with_fk if with_fk else self._statique df = src[self._get_fields_for_schema(schema, with_fk=with_fk)] df = df.drop_duplicates(subset) - df = self._add_timestamped_model_fields(df) + df = self._add_auditable_model_fields(df) return df def _add_fk_from_saved_schema( - self, saved: pd.DataFrame, schema: type[BaseTimestampedSQLModel] + self, saved: pd.DataFrame, schema: type[BaseAuditableSQLModel] ): """Add foreign keys to the statique DataFrame using saved schema.""" fields = self._get_fields_for_schema(schema) @@ -214,7 +221,7 @@ def station(self) -> pd.DataFrame: def _save_schema( self, df: pd.DataFrame, - schema: type[BaseTimestampedSQLModel], + schema: type[BaseAuditableSQLModel], constraint: Optional[str] = None, index_elements: Optional[list[str]] = None, chunksize: int = 1000, @@ -243,7 +250,12 @@ def _save_schema( f: stmt.excluded.get(f) for f in self._get_fields_for_schema(schema, with_fk=True) } - 
updates_on_conflict.update({"updated_at": stmt.excluded.updated_at}) + updates_on_conflict.update( + { + "updated_at": stmt.excluded.updated_at, + "updated_by_id": stmt.excluded.updated_by_id, + } + ) stmt = stmt.on_conflict_do_update( constraint=constraint, index_elements=index_elements, diff --git a/src/api/qualicharge/schemas/utils.py b/src/api/qualicharge/schemas/utils.py index 4dbc67bc..b168971b 100644 --- a/src/api/qualicharge/schemas/utils.py +++ b/src/api/qualicharge/schemas/utils.py @@ -12,6 +12,7 @@ from sqlmodel import Session, SQLModel, select from qualicharge.auth.schemas import User +from qualicharge.schemas.audit import BaseAuditableSQLModel from ..exceptions import ( DatabaseQueryException, @@ -32,7 +33,13 @@ logger = logging.getLogger(__name__) -DB_TO_STATIC_EXCLUDED_FIELDS = {"id", "created_at", "updated_at"} +DB_TO_STATIC_EXCLUDED_FIELDS = { + "id", + "created_at", + "updated_at", + "created_by_id", + "updated_by_id", +} class EntryStatus(IntEnum): @@ -84,7 +91,7 @@ def get_or_create( # Update database entry for key, value in entry.model_dump( - exclude=DB_TO_STATIC_EXCLUDED_FIELDS + exclude=set(DB_TO_STATIC_EXCLUDED_FIELDS) - {"updated_by_id"} ).items(): setattr(db_entry, key, value) session.add(db_entry) @@ -98,12 +105,13 @@ def get_or_create( return EntryStatus.CREATED, entry -def save_schema_from_statique( +def save_schema_from_statique( # noqa: PLR0913 session: Session, schema_klass: Type[SQLModel], statique: Statique, fields: Optional[Set] = None, update: bool = False, + author: Optional[User] = None, ) -> Tuple[EntryStatus, SQLModel]: """Save schema to database from Statique instance. @@ -114,6 +122,7 @@ def save_schema_from_statique( fields: entry fields used in database query to select target entry. Defaults to None (use all fields). update: should we update existing instance if required? + author: the user that creates/updates the schema entry Returns: A (EntryStatus, entry) tuple. The status refers on the entry creation/update. 
@@ -123,6 +132,12 @@ def save_schema_from_statique( """ # Is this a new entry? entry = schema_klass(**statique.get_fields_for_schema(schema_klass)) + + # Add author for auditability + if issubclass(schema_klass, BaseAuditableSQLModel): + entry.created_by_id = author.id if author else None + entry.updated_by_id = author.id if author else None + return get_or_create( session, entry, @@ -144,23 +159,38 @@ def pdc_to_statique(pdc: PointDeCharge) -> Statique: def save_statique( - session: Session, statique: Statique, update: bool = False + session: Session, + statique: Statique, + update: bool = False, + author: Optional[User] = None, ) -> Statique: """Save Statique instance to database.""" # Core schemas _, pdc = save_schema_from_statique( - session, PointDeCharge, statique, fields={"id_pdc_itinerance"}, update=update + session, + PointDeCharge, + statique, + fields={"id_pdc_itinerance"}, + update=update, + author=author, ) _, station = save_schema_from_statique( - session, Station, statique, fields={"id_station_itinerance"}, update=update + session, + Station, + statique, + fields={"id_station_itinerance"}, + update=update, + author=author, ) _, amenageur = save_schema_from_statique( - session, Amenageur, statique, update=update + session, Amenageur, statique, update=update, author=author ) _, operateur = save_schema_from_statique( - session, Operateur, statique, update=update + session, Operateur, statique, update=update, author=author + ) + _, enseigne = save_schema_from_statique( + session, Enseigne, statique, update=update, author=author ) - _, enseigne = save_schema_from_statique(session, Enseigne, statique, update=update) _, localisation = save_schema_from_statique( session, Localisation, @@ -169,6 +199,7 @@ def save_statique( "adresse_station", }, update=update, + author=author, ) # Relationships @@ -189,7 +220,10 @@ def save_statique( def update_statique( - session: Session, id_pdc_itinerance: str, to_update: Statique + session: Session, + id_pdc_itinerance: 
str, + to_update: Statique, + author: Optional[User] = None, ) -> Statique: """Update given statique from its id_pdc_itinerance.""" # Check that submitted id_pdc_itinerance corresponds to the update @@ -209,17 +243,19 @@ def update_statique( ): raise ObjectDoesNotExist("Statique with id_pdc_itinerance does not exist") - return save_statique(session, to_update, update=True) + return save_statique(session, to_update, update=True, author=author) -def save_statiques(db_session: Session, statiques: List[Statique]): +def save_statiques( + db_session: Session, statiques: List[Statique], author: Optional[User] = None +): """Save input statiques to database.""" df = pd.read_json( StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"), lines=True, dtype_backend="pyarrow", ) - importer = StatiqueImporter(df, db_session.connection()) + importer = StatiqueImporter(df, db_session.connection(), author=author) importer.save() diff --git a/src/api/tests/api/v1/routers/test_statique.py b/src/api/tests/api/v1/routers/test_statique.py index 9169458f..8705f04d 100644 --- a/src/api/tests/api/v1/routers/test_statique.py +++ b/src/api/tests/api/v1/routers/test_statique.py @@ -514,6 +514,63 @@ def test_update_for_superuser(client_auth, db_session): assert json_response == json.loads(new_statique.model_dump_json()) +def test_update_audits(client_auth, db_session): + """Test the /statique/{id_pdc_itinerance} update endpoint audits. 
+ + (superuser case) + """ + id_pdc_itinerance = "FR911E1111ER1" + db_statique = save_statique( + db_session, + StatiqueFactory.build( + id_pdc_itinerance=id_pdc_itinerance, + nom_amenageur="ACME Inc.", + nom_operateur="ACME Inc.", + nom_enseigne="ACME Inc.", + coordonneesXY=Coordinate(-1.0, 1.0), + station_deux_roues=False, + cable_t2_attache=False, + ), + ) + station = db_session.exec( + select(Station).where( + Station.id_station_itinerance == db_statique.id_station_itinerance + ) + ).one() + + assert len(station.audits) == 0 + + new_statique = db_statique.model_copy( + update={ + "contact_operateur": "john@doe.com", + "nom_amenageur": "Magma Corp.", + "nom_operateur": "Magma Corp.", + "nom_enseigne": "Magma Corp.", + "coordonneesXY": Coordinate(1.0, 2.0), + "station_deux_roues": True, + "cable_t2_attache": True, + }, + deep=True, + ) + + response = client_auth.put( + f"/statique/{id_pdc_itinerance}", + json=json.loads(new_statique.model_dump_json()), + ) + assert response.status_code == status.HTTP_200_OK + db_session.refresh(station) + + # Get user requesting the server + user = db_session.exec(select(User).where(User.email == "john@doe.com")).one() + + # We expect two audits as FKs are updated in a second request (once all other models + # have been updated). 
+ expected_audits = 2 + assert len(station.audits) == expected_audits + assert station.audits[0].author_id == user.id + assert station.audits[1].author_id == user.id + + @pytest.mark.parametrize( "client_auth", ( @@ -870,3 +927,56 @@ def test_bulk_update(client_auth, db_session): .one() .paiement_cb ) + + +def test_bulk_update_audits(client_auth, db_session): + """Test that bulk endpoint updates audits.""" + size = 10 + statiques = StatiqueFactory.batch( + size=size, + paiement_cb=False, + ) + save_statiques(db_session, statiques) + + # Update paiement_cb field + statiques[3].paiement_cb = statiques[7].paiement_cb = True + + payload = [json.loads(s.model_dump_json()) for s in statiques] + response = client_auth.post("/statique/bulk", json=payload) + assert response.status_code == status.HTTP_201_CREATED + + # Get user requesting the server + user = db_session.exec(select(User).where(User.email == "john@doe.com")).one() + + # Check changes and audit + pdc3 = db_session.exec( + select(PointDeCharge).where( + PointDeCharge.id_pdc_itinerance == statiques[3].id_pdc_itinerance + ) + ).one() + assert pdc3.paiement_cb + assert len(pdc3.audits) == 1 + assert pdc3.audits[0].author_id == user.id + + pdc7 = db_session.exec( + select(PointDeCharge).where( + PointDeCharge.id_pdc_itinerance == statiques[7].id_pdc_itinerance + ) + ).one() + assert pdc7.paiement_cb + assert len(pdc7.audits) == 1 + assert pdc7.audits[0].author_id == user.id + + # Check no other changes exist + not_updated = statiques.copy() + not_updated.pop(3) + not_updated.pop(7) + pdcs = db_session.exec( + select(PointDeCharge).where( + PointDeCharge.id_pdc_itinerance.in_( + s.id_pdc_itinerance for s in not_updated + ) + ) + ).all() + for pdc in pdcs: + assert len(pdc.audits) == 0 diff --git a/src/api/tests/schemas/test_audit.py b/src/api/tests/schemas/test_audit.py new file mode 100644 index 00000000..f158d086 --- /dev/null +++ b/src/api/tests/schemas/test_audit.py @@ -0,0 +1,187 @@ +"""QualiCharge auditable 
schemas tests.""" + +import pytest +from sqlalchemy import func +from sqlmodel import select + +from qualicharge.auth.factories import UserFactory +from qualicharge.factories.static import ( + AmenageurFactory, + EnseigneFactory, + LocalisationFactory, + OperateurFactory, + PointDeChargeFactory, + StationFactory, +) +from qualicharge.schemas.audit import Audit +from qualicharge.schemas.core import ( + Operateur, +) + + +def test_auditable_schema_changes(db_session): + """Test an updated schema instance creates a new Audit entry.""" + OperateurFactory.__session__ = db_session + UserFactory.__session__ = db_session + + user = UserFactory.create_sync() + + # Check initial database state + assert db_session.exec(select(func.count(Operateur.id))).one() == 0 + assert db_session.exec(select(func.count(Audit.id))).one() == 0 + + # Persist an operateur without creator or updator + operateur = OperateurFactory.create_sync( + nom_operateur="Doe inc.", + contact_operateur="john@doe.com", + telephone_operateur="+33144276350", + ) + + # Check database state + assert db_session.exec(select(func.count(Operateur.id))).one() == 1 + assert db_session.exec(select(func.count(Audit.id))).one() == 0 + + # Update operateur without updator + operateur.contact_operateur = "jane@doe.com" + operateur.telephone_operateur = "+33144276351" + db_session.add(operateur) + + # Check database state + assert db_session.exec(select(func.count(Operateur.id))).one() == 1 + assert db_session.exec(select(func.count(Audit.id))).one() == 0 + + # Now update operateur with an updator + operateur.updated_by_id = user.id + operateur.contact_operateur = "janine@doe.com" + operateur.telephone_operateur = "+33144276352" + db_session.add(operateur) + + # Check database state + assert db_session.exec(select(func.count(Operateur.id))).one() == 1 + assert db_session.exec(select(func.count(Audit.id))).one() == 1 + audit = db_session.exec(select(Audit)).first() + assert audit.table == "operateur" + assert audit.author_id 
== user.id + assert audit.target_id == operateur.id + assert audit.updated_at == operateur.updated_at + assert audit.changes == { + "updated_by_id": ["None", str(user.id)], + "contact_operateur": ["jane@doe.com", "janine@doe.com"], + "telephone_operateur": ["tel:+33-1-44-27-63-51", "tel:+33-1-44-27-63-52"], + } + + # Perform new updates + operateur.contact_operateur = "janot@doe.com" + operateur.telephone_operateur = "+33144276353" + db_session.add(operateur) + + # Check database state + expected_audits = 2 + assert db_session.exec(select(func.count(Operateur.id))).one() == 1 + assert db_session.exec(select(func.count(Audit.id))).one() == expected_audits + audit = db_session.exec(select(Audit).order_by(Audit.updated_at.desc())).first() + assert audit.table == "operateur" + assert audit.author_id == user.id + assert audit.target_id == operateur.id + assert audit.updated_at == operateur.updated_at + assert audit.changes == { + "contact_operateur": ["janine@doe.com", "janot@doe.com"], + "telephone_operateur": ["tel:+33-1-44-27-63-52", "tel:+33-1-44-27-63-53"], + } + + +def test_auditable_schema_audits_dynamic_fk(db_session): + """Test auditable schema dynamic audits foreign key.""" + OperateurFactory.__session__ = db_session + UserFactory.__session__ = db_session + + user = UserFactory.create_sync() + operateur = OperateurFactory.create_sync( + nom_operateur="Doe inc.", + contact_operateur="john@doe.com", + telephone_operateur="+33144276350", + updated_by_id=user.id, + ) + + assert len(operateur.audits) == 0 + + # Update operateur + operateur.contact_operateur = "janine@doe.com" + operateur.telephone_operateur = "+33144276352" + db_session.add(operateur) + db_session.commit() + db_session.refresh(operateur) + + # Test audits dymanic generic FK + assert len(operateur.audits) == 1 + assert operateur.audits[0].table == "operateur" + assert operateur.audits[0].author_id == user.id + assert operateur.audits[0].target_id == operateur.id + assert 
operateur.audits[0].updated_at == operateur.updated_at + assert operateur.audits[0].changes == { + "contact_operateur": ["john@doe.com", "janine@doe.com"], + "telephone_operateur": ["tel:+33-1-44-27-63-50", "tel:+33-1-44-27-63-52"], + } + + # Update operateur once again + operateur.contact_operateur = "janot@doe.com" + operateur.telephone_operateur = "+33144276353" + db_session.add(operateur) + db_session.commit() + db_session.refresh(operateur) + + # Test audits dymanic generic FK + expected_audits = 2 + assert len(operateur.audits) == expected_audits + assert operateur.audits[1].table == "operateur" + assert operateur.audits[1].author_id == user.id + assert operateur.audits[1].target_id == operateur.id + assert operateur.audits[1].updated_at == operateur.updated_at + assert operateur.audits[1].changes == { + "contact_operateur": ["janine@doe.com", "janot@doe.com"], + "telephone_operateur": ["tel:+33-1-44-27-63-52", "tel:+33-1-44-27-63-53"], + } + + +@pytest.mark.parametrize( + "factory,extras", + ( + (AmenageurFactory, {}), + (OperateurFactory, {}), + (EnseigneFactory, {}), + (LocalisationFactory, {}), + (StationFactory, {}), + (PointDeChargeFactory, {"station_id": None}), + ), +) +def test_auditable_schema_audits(db_session, factory, extras): + """Test auditable schema dynamic audits foreign key.""" + factory.__session__ = db_session + UserFactory.__session__ = db_session + + user1 = UserFactory.create_sync() + user2 = UserFactory.create_sync() + instance = factory.create_sync(created_by_id=user1.id, updated_by_id=None, **extras) + + assert len(instance.audits) == 0 + + # Update + instance.updated_by_id = user2.id + db_session.add(instance) + db_session.commit() + db_session.refresh(instance) + + # Test audits dymanic generic FK + assert len(instance.audits) == 1 + assert instance.audits[0].author_id == user2.id + + # Update instance once again + instance.updated_by_id = user1.id + db_session.add(instance) + db_session.commit() + db_session.refresh(instance) + + 
# Test audits dynamic generic FK + expected_audits = 2 + assert len(instance.audits) == expected_audits + assert instance.audits[1].author_id == user1.id