From 34942daac2cf3a4e43f3e22a7f59cbd5cf373a89 Mon Sep 17 00:00:00 2001
From: Julien Maupetit
Date: Mon, 29 Jul 2024 18:37:48 +0200
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8(api)=20add=20Pandas-based=20StatiqueI?=
 =?UTF-8?q?mporter?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

We designed an efficient, ORM-free, Pandas-based Statique importer. It
is expected to be roughly 10x faster than the ORM-based routines
implemented in the /statique/bulk endpoint.
---
 docker-compose.yml                 |   1 +
 src/api/CHANGELOG.md               |   5 +
 src/api/Pipfile                    |   1 +
 src/api/Pipfile.lock               |  89 ++++++++-
 src/api/qualicharge/cli.py         |  32 ++++
 src/api/qualicharge/exceptions.py  |   4 +
 src/api/qualicharge/schemas/sql.py | 295 +++++++++++++++++++++++++++++
 src/api/tests/schemas/test_sql.py  | 167 ++++++++++++++++
 src/api/tests/test_cli.py          |  86 ++++++++-
 src/notebook/misc/import-static.md |  86 ++++++++-
 10 files changed, 757 insertions(+), 9 deletions(-)
 create mode 100644 src/api/qualicharge/schemas/sql.py
 create mode 100644 src/api/tests/schemas/test_sql.py

diff --git a/docker-compose.yml b/docker-compose.yml
index 478f5901..31ad51fd 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -40,6 +40,7 @@ services:
     restart: always
     volumes:
       - ./src/api:/app
+      - ./data:/data
    depends_on:
      - postgresql

diff --git a/src/api/CHANGELOG.md b/src/api/CHANGELOG.md
index 6d785d90..ca66864f 100644
--- a/src/api/CHANGELOG.md
+++ b/src/api/CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to
 
 ## [Unreleased]
 
+### Added
+
+- Implement Pandas-based `StatiqueImporter`
+- CLI: add `import-static` command
+
 ### Changed
 
 - Upgrade fastapi to `0.112.2`
diff --git a/src/api/Pipfile b/src/api/Pipfile
index e4fb8b66..6cffea26 100644
--- a/src/api/Pipfile
+++ b/src/api/Pipfile
@@ -32,6 +32,7 @@ black = "==24.8.0"
 csvkit = "==2.0.1"
 honcho = "==1.1.0"
 mypy = "==1.11.2"
+pandas-stubs = "==2.2.2.240807"
 polyfactory = "==2.16.2"
 pytest = "==8.3.2"
 pytest-cov = "==5.0.0"
diff --git a/src/api/Pipfile.lock b/src/api/Pipfile.lock
index 7439d439..f918239d 100644
--- a/src/api/Pipfile.lock
+++ b/src/api/Pipfile.lock
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "dcea5518d34d8dd180af5d658de7b0b1949fd54a7984012dd308767a855f73a5"
+            "sha256": "a732ef38cfad1fecfa4583f609c3667f18ccf64c7cae4b74785b0a5fa9762143"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -568,10 +568,10 @@
         },
         "phonenumbers": {
             "hashes": [
-                "sha256:339e521403fe4dd9c664dbbeb2fe434f9ea5c81e54c0fdfadbaeb53b26a76c27",
-                "sha256:35b904e4a79226eee027fbb467a9aa6f1ab9ffc3c09c91bf14b885c154936726"
+                "sha256:2175021e84ee4e41b43c890f2d0af51f18c6ca9ad525886d6d6e4ea882e46fac",
+                "sha256:52cd02865dab1428ca9e89d442629b61d407c7dc687cfb80a3e8d068a584513c"
             ],
-            "version": "==8.13.43"
+            "version": "==8.13.44"
         },
         "prompt-toolkit": {
             "hashes": [
@@ -1021,11 +1021,11 @@
         },
         "rich": {
             "hashes": [
-                "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222",
-                "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"
+                "sha256:2e85306a063b9492dffc86278197a60cbece75bcb766022f3436f567cae11bdc",
+                "sha256:a5ac1f1cd448ade0d59cc3356f7db7a7ccda2c8cbae9c7a90c28ff463d3e91f4"
             ],
             "markers": "python_full_version >= '3.7.0'",
-            "version": "==13.7.1"
+            "version": "==13.8.0"
         },
         "semver": {
             "hashes": [
@@ -1923,6 +1923,64 @@
             "markers": "python_version >= '3.5'",
             "version": "==1.0.0"
         },
+        "numpy": {
+            "hashes": [
+                "sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b",
+                "sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911",
+                
"sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977", + "sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84", + "sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b", + "sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33", + "sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b", + "sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d", + "sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111", + "sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673", + "sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06", + "sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36", + "sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f", + "sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd", + "sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e", + "sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62", + "sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb", + "sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300", + "sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b", + "sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb", + "sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8", + "sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195", + "sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2", + "sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce", + "sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6", + "sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2", + "sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33", + "sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b", + "sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667", + "sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1", + "sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a", + "sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e", + "sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745", + "sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc", + "sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324", + "sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0", + "sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8", + "sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02", + "sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574", + "sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd", + "sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1", + "sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5", + "sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d", + "sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883", + "sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575", + "sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529", + "sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa", + "sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3", + 
"sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211", + "sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1", + "sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736", + "sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e" + ], + "markers": "python_version >= '3.12'", + "version": "==2.1.0" + }, "olefile": { "hashes": [ "sha256:543c7da2a7adadf21214938bb79c83ea12b473a4b6ee4ad4bf854e7715e13d1f", @@ -1947,6 +2005,15 @@ "markers": "python_version >= '3.8'", "version": "==24.1" }, + "pandas-stubs": { + "hashes": [ + "sha256:64a559725a57a449f46225fbafc422520b7410bff9252b661a225b5559192a93", + "sha256:893919ad82be4275f0d07bb47a95d08bae580d3fdea308a7acfcb3f02e76186e" + ], + "index": "pypi", + "markers": "python_version >= '3.9'", + "version": "==2.2.2.240807" + }, "parsedatetime": { "hashes": [ "sha256:4cb368fbb18a0b7231f4d76119165451c8d2e35951455dfee97c62a87b04d455", @@ -2166,6 +2233,14 @@ "markers": "python_version >= '3.8'", "version": "==3.3.4.20240106" }, + "types-pytz": { + "hashes": [ + "sha256:6810c8a1f68f21fdf0f4f374a432487c77645a0ac0b31de4bf4690cf21ad3981", + "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659" + ], + "markers": "python_version >= '3.8'", + "version": "==2024.1.0.20240417" + }, "typing-extensions": { "hashes": [ "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", diff --git a/src/api/qualicharge/cli.py b/src/api/qualicharge/cli.py index 528069f5..e29069c0 100644 --- a/src/api/qualicharge/cli.py +++ b/src/api/qualicharge/cli.py @@ -1,13 +1,19 @@ """QualiCharge CLI.""" +import logging +from pathlib import Path from typing import Optional, Sequence, cast +import pandas as pd import questionary import typer +from psycopg import Error as PGError from rich import print from rich.console import Console +from rich.logging import RichHandler from rich.table import Table from sqlalchemy import Column as SAColumn +from sqlalchemy.exc import IntegrityError, OperationalError, ProgrammingError from sqlmodel import Session as SMSession from sqlmodel import select @@ -15,9 +21,14 @@ from .auth.schemas import Group, ScopesEnum, User from .conf import settings from .db import get_session +from .exceptions import IntegrityError as QCIntegrityError from .fixtures.operational_units import prefixes from .schemas.core import OperationalUnit +from .schemas.sql import StatiqueImporter +logging.basicConfig( + level=logging.INFO, format="%(message)s", datefmt="[%X]", handlers=[RichHandler()] +) app = typer.Typer(name="qualicharge", no_args_is_help=True) console = Console() @@ -416,6 +427,27 @@ def delete_user(ctx: typer.Context, username: str, force: bool = False): print(f"[bold yellow]User {username} deleted.[/bold yellow]") +@app.command() +def import_static(ctx: typer.Context, input_file: Path): + """Import Statique file (parquet format).""" + session: SMSession = ctx.obj + + # Load dataset + console.log(f"Reading input file: {input_file}") + static = pd.read_parquet(input_file) + console.log(f"Read {len(static.index)} rows") + importer = StatiqueImporter(static, session.connection()) + + console.log("Save to configured database") + try: + importer.save() + except (ProgrammingError, IntegrityError, OperationalError, PGError) as err: + session.rollback() + raise QCIntegrityError("Input file importation failed. 
Rolling back.") from err
+    session.commit()
+    console.log("Saved (or updated) all entries successfully.")
+
+
 @app.callback()
 def main(ctx: typer.Context):
     """Attach database session to the context object."""
diff --git a/src/api/qualicharge/exceptions.py b/src/api/qualicharge/exceptions.py
index 50318e05..d0fc5494 100644
--- a/src/api/qualicharge/exceptions.py
+++ b/src/api/qualicharge/exceptions.py
@@ -39,3 +39,7 @@ class IntegrityError(QualiChargeExceptionMixin, Exception):
 
 class ObjectDoesNotExist(QualiChargeExceptionMixin, Exception):
     """Raised when queried object does not exist."""
+
+
+class ProgrammingError(QualiChargeExceptionMixin, Exception):
+    """Raised when the QC object API is misused."""
diff --git a/src/api/qualicharge/schemas/sql.py b/src/api/qualicharge/schemas/sql.py
new file mode 100644
index 00000000..dc1a12c5
--- /dev/null
+++ b/src/api/qualicharge/schemas/sql.py
@@ -0,0 +1,295 @@
+"""QualiCharge SQL module.
+
+This module groups ORM-free methods used to import data in bulk.
+"""
+
+import json
+import logging
+import uuid
+
+import geopandas as gp  # type: ignore
+import pandas as pd
+from shapely import to_wkt
+from shapely.geometry import Point
+from sqlalchemy import Table
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.engine import Connection
+from sqlalchemy.schema import MetaData
+from typing_extensions import Optional
+
+from ..exceptions import ProgrammingError
+from ..models.static import Statique
+from . import BaseTimestampedSQLModel
+from .core import (
+    AccessibilitePMREnum,
+    Amenageur,
+    ConditionAccesEnum,
+    Enseigne,
+    ImplantationStationEnum,
+    Localisation,
+    Operateur,
+    PointDeCharge,
+    RaccordementEnum,
+    Station,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class StatiqueImporter:
+    """Statique importer from a Pandas DataFrame."""
+
+    def __init__(self, df: pd.DataFrame, connection: Connection):
+        """Initialize the importer with a source dataframe and a connection."""
+        logger.info("Loading input dataframe containing %d rows", len(df))
+
+        self._statique: pd.DataFrame = self._fix_enums(df)
+        self._statique_with_fk: pd.DataFrame = self._statique.copy()
+        self._saved_schemas: list[type[BaseTimestampedSQLModel]] = []
+
+        self._amenageur: Optional[pd.DataFrame] = None
+        self._enseigne: Optional[pd.DataFrame] = None
+        self._localisation: Optional[gp.GeoDataFrame] = None
+        self._operateur: Optional[pd.DataFrame] = None
+        self._pdc: Optional[pd.DataFrame] = None
+        self._station: Optional[pd.DataFrame] = None
+
+        self._operational_units: Optional[pd.DataFrame] = None
+
+        self.connection: Connection = connection
+
+    def __len__(self):
+        """Object length corresponds to the static dataframe length."""
+        return len(self._statique)
+
+    @staticmethod
+    def _add_timestamped_model_fields(df: pd.DataFrame):
+        """Add required fields for a BaseTimestampedSQLModel."""
+        df["id"] = df.apply(lambda x: uuid.uuid4(), axis=1)
+        now = pd.Timestamp.now(tz="utc")
+        df["created_at"] = now
+        df["updated_at"] = now
+        return df
+
+    @staticmethod
+    def _schema_fk(schema: type[BaseTimestampedSQLModel]) -> str:
+        """Get expected schema foreign key name."""
+        return f"{schema.__table__.name}_id"  # type: ignore[attr-defined]
+
+    @staticmethod
+    def _get_schema_fks(schema: type[BaseTimestampedSQLModel]) -> list[str]:
+        """Get foreign key field names from a schema."""
+        return [
+            fk.parent.name
+            for fk in schema.metadata.tables[schema.__tablename__].foreign_keys  # type: ignore[index]
+        ]
+
+    def _fix_enums(self, df: pd.DataFrame) -> pd.DataFrame:
+        """Fix enums representation in dataframe."""
+        
logger.debug("Fixing enum columns representation") + target = [] + src = [] + + for enum_ in ( + ImplantationStationEnum, + ConditionAccesEnum, + AccessibilitePMREnum, + RaccordementEnum, + ): + for entry in enum_: + target.append(str(entry.name)) + src.append(entry.value) + + return df.replace(to_replace=src, value=target) + + def _get_fields_for_schema( + self, schema: type[BaseTimestampedSQLModel], with_fk: bool = False + ) -> list[str]: + """Get Statique fields from a core schema.""" + fields = list( + set(Statique.model_fields.keys()) & set(schema.model_fields.keys()) + ) + if with_fk: + fields += self._get_schema_fks(schema) + return fields + + def _get_dataframe_for_schema( + self, + schema: type[BaseTimestampedSQLModel], + subset: Optional[str] = None, + with_fk: bool = False, + ): + """Extract Schema DataFrame from original Statique DataFrame.""" + src = self._statique_with_fk if with_fk else self._statique + df = src[self._get_fields_for_schema(schema, with_fk=with_fk)] + df = df.drop_duplicates(subset) + df = self._add_timestamped_model_fields(df) + return df + + def _add_fk_from_saved_schema( + self, saved: pd.DataFrame, schema: type[BaseTimestampedSQLModel] + ): + """Add foreign keys to the statique DataFrame using saved schema.""" + fields = self._get_fields_for_schema(schema) + # coordonneesXY field cannot be used for merging + fields = list(set(fields) - {"coordonneesXY"}) + left = self._statique_with_fk + right = saved[fields + [self._schema_fk(schema)]] + self._statique_with_fk = left.merge(right, how="left", on=fields) + + def _load_operational_units(self): + """Query database to get Operational Units.""" + logger.info("Loading operational units from database") + self._operational_units = pd.read_sql_table( + "operationalunit", self.connection, columns=["id", "code"] + ) + + def _add_operational_units_fk(self): + """Add operational units fk in statique with fk dataframe.""" + logger.info("Merging operational unit foreign keys") + if self._operational_units is None: + self._load_operational_units() + left = self._statique_with_fk + left["code"] = left["id_station_itinerance"].str.slice(stop=5) + left = left.merge(self._operational_units, how="left", on="code") + left.drop(columns="code", inplace=True) + left.rename(columns={"id": "operational_unit_id"}, inplace=True) + self._statique_with_fk = left + + @property + def amenageur(self) -> pd.DataFrame: + """Get Amenageur Dataframe.""" + if self._amenageur is None: + self._amenageur = self._get_dataframe_for_schema(Amenageur) + return self._amenageur + + @property + def enseigne(self) -> pd.DataFrame: + """Get Enseigne Dataframe.""" + if self._enseigne is None: + self._enseigne = self._get_dataframe_for_schema(Enseigne) + return self._enseigne + + @property + def localisation(self) -> pd.DataFrame: + """Get localisation DataFrame.""" + if self._localisation is None: + df = self._get_dataframe_for_schema(Localisation) + # We need a WKT representation for bulk insertion + df["coordonneesXY"] = ( + df["coordonneesXY"].map(json.loads).map(Point).map(to_wkt) + ) + self._localisation = df + return self._localisation + + @property + def operateur(self) -> pd.DataFrame: + """Get Operateur Dataframe.""" + if self._operateur is None: + self._operateur = self._get_dataframe_for_schema(Operateur) + return self._operateur + + @property + def pdc(self) -> pd.DataFrame: + """Get PointDeCharge Dataframe.""" + if self._pdc is None: + self._pdc = self._get_dataframe_for_schema(PointDeCharge, with_fk=True) + return self._pdc + + @property + 
def station(self) -> pd.DataFrame: + """Get Station Dataframe.""" + if self._station is None: + self._station = self._get_dataframe_for_schema(Station, with_fk=True) + return self._station + + def _save_schema( + self, + df: pd.DataFrame, + schema: type[BaseTimestampedSQLModel], + constraint: Optional[str] = None, + index_elements: Optional[list[str]] = None, + chunksize: int = 1000, + ) -> pd.DataFrame: + """Save given dataframe records to the corresponding schema.""" + logger.info("Saving schema %s (%d rows)", schema.__qualname__, len(df)) + + if schema in self._saved_schemas: + raise ProgrammingError( + ( + "You cannot save the same schema more than once. " + "You should create a new StatiqueImporter instance instead." + ) + ) + + schema_table = Table( + schema.__table__.name, # type: ignore[attr-defined] + MetaData(), + autoload_with=self.connection, + ) + + fks = pd.Series() + for chunk in [df[i : i + chunksize] for i in range(0, len(df), chunksize)]: + stmt = insert(schema_table).values(chunk.to_dict("records")) + updates_on_conflict = { + f: stmt.excluded.get(f) + for f in self._get_fields_for_schema(schema, with_fk=True) + } + updates_on_conflict.update({"updated_at": stmt.excluded.updated_at}) + stmt = stmt.on_conflict_do_update( + constraint=constraint, + index_elements=index_elements, + set_=updates_on_conflict, + ) + stmt_ret = stmt.returning(schema_table.c.id) + + result = self.connection.execute(stmt_ret) + fks = pd.concat( + [ + fks, + pd.Series(data=[row.id for row in result.all()], index=chunk.index), + ] + ) + + # Leave the original dataframe untouched + cp = df.copy(deep=True) + cp.insert(0, self._schema_fk(schema), fks) + self._add_fk_from_saved_schema(cp, schema) + self._saved_schemas += [schema] + + return cp + + def save(self): + """Save (or update) statique entries.""" + self._add_operational_units_fk() + + self._save_schema( + self.amenageur, + Amenageur, + constraint="amenageur_nom_amenageur_siren_amenageur_contact_amenageur_key", + ) + self._save_schema( + self.operateur, + Operateur, + constraint="operateur_nom_operateur_contact_operateur_telephone_operate_key", + ) + self._save_schema( + self.enseigne, + Enseigne, + constraint="enseigne_nom_enseigne_key", + ) + self._save_schema( + self.localisation, + Localisation, + constraint="localisation_adresse_station_key", + ) + self._save_schema( + self.station, + Station, + index_elements=["id_station_itinerance"], + ) + self._save_schema( + self.pdc, + PointDeCharge, + index_elements=["id_pdc_itinerance"], + ) diff --git a/src/api/tests/schemas/test_sql.py b/src/api/tests/schemas/test_sql.py new file mode 100644 index 00000000..f771ca26 --- /dev/null +++ b/src/api/tests/schemas/test_sql.py @@ -0,0 +1,167 @@ +"""Tests for QualiCharge SQL importer.""" + +from io import StringIO +from math import isclose + +import pandas as pd +import pytest +from sqlalchemy import func +from sqlmodel import select + +from qualicharge.exceptions import ProgrammingError +from qualicharge.factories.static import StatiqueFactory +from qualicharge.schemas.core import ( + Amenageur, + Enseigne, + Localisation, + Operateur, + PointDeCharge, + Station, +) +from qualicharge.schemas.sql import StatiqueImporter + + +def test_statique_importer_properties(db_session): + """Test the StatiqueImporter properties.""" + # Create statique data to import + size = 5 + statiques = StatiqueFactory.batch(size=size) + df = pd.read_json( + StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"), + lines=True, + dtype_backend="pyarrow", + ) + importer 
= StatiqueImporter(df, db_session.connection()) + + assert len(importer) == size + assert len(importer.amenageur.index) == size + assert len(importer.enseigne.index) == size + assert len(importer.operateur.index) == size + assert len(importer.localisation.index) == size + + with pytest.raises(KeyError, match="not in index"): + assert importer.station + + with pytest.raises(KeyError, match="not in index"): + assert importer.pdc + + +def test_statique_importer_save_or_update(db_session): + """Test the StatiqueImporter save (or update) feature.""" + # Create statique data to import + size = 5 + statiques = StatiqueFactory.batch(size=size) + df = pd.read_json( + StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"), + lines=True, + dtype_backend="pyarrow", + ) + importer = StatiqueImporter(df, db_session.connection()) + + # No database records exist yet + assert db_session.exec(select(func.count(Amenageur.id))).one() == 0 + assert db_session.exec(select(func.count(Enseigne.id))).one() == 0 + assert db_session.exec(select(func.count(Localisation.id))).one() == 0 + assert db_session.exec(select(func.count(Operateur.id))).one() == 0 + assert db_session.exec(select(func.count(PointDeCharge.id))).one() == 0 + assert db_session.exec(select(func.count(Station.id))).one() == 0 + + # Save to database + importer.save() + + # Assert we've created expected records + assert db_session.exec(select(func.count(Amenageur.id))).one() == size + assert db_session.exec(select(func.count(Enseigne.id))).one() == size + assert db_session.exec(select(func.count(Localisation.id))).one() == size + assert db_session.exec(select(func.count(Operateur.id))).one() == size + assert db_session.exec(select(func.count(PointDeCharge.id))).one() == size + assert db_session.exec(select(func.count(Station.id))).one() == size + + # Check save() cannot be used more than once + with pytest.raises( + ProgrammingError, match="You cannot save the same schema more than once." 
+ ): + importer.save() + + # Save it again and make sure we've updated records + importer = StatiqueImporter(df, db_session.connection()) + importer.save() + + # Assert we've created expected records + assert ( + db_session.exec(select(func.count(Amenageur.id))).one() >= size + ) # too permissive :( + assert db_session.exec(select(func.count(Enseigne.id))).one() == size + assert db_session.exec(select(func.count(Localisation.id))).one() == size + assert db_session.exec(select(func.count(Operateur.id))).one() >= size # ditto + assert db_session.exec(select(func.count(PointDeCharge.id))).one() == size + assert db_session.exec(select(func.count(Station.id))).one() == size + + +def test_statique_importer_consistency(db_session): + """Test the StatiqueImporter consistency.""" + # Create statique data to import + size = 20 + statiques = StatiqueFactory.batch(size=size) + df = pd.read_json( + StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"), + lines=True, + dtype_backend="pyarrow", + ) + importer = StatiqueImporter(df, db_session.connection()) + importer.save() + + for statique in statiques: + pdc = db_session.exec( + select(PointDeCharge).where( + PointDeCharge.id_pdc_itinerance == statique.id_pdc_itinerance + ) + ).one() + assert pdc.id_pdc_itinerance == statique.id_pdc_itinerance + assert pdc.id_pdc_local == statique.id_pdc_local + assert isclose(pdc.puissance_nominale, statique.puissance_nominale) + assert pdc.prise_type_ef == statique.prise_type_ef + assert pdc.prise_type_2 == statique.prise_type_2 + assert pdc.prise_type_combo_ccs == statique.prise_type_combo_ccs + assert pdc.prise_type_chademo == statique.prise_type_chademo + assert pdc.prise_type_autre == statique.prise_type_autre + assert pdc.gratuit == statique.gratuit + assert pdc.paiement_acte == statique.paiement_acte + assert pdc.paiement_cb == statique.paiement_cb + assert pdc.paiement_autre == statique.paiement_autre + assert pdc.tarification == statique.tarification + assert pdc.reservation == statique.reservation + assert pdc.accessibilite_pmr == statique.accessibilite_pmr + assert pdc.restriction_gabarit == statique.restriction_gabarit + assert pdc.observations == statique.observations + assert pdc.cable_t2_attache == statique.cable_t2_attache + assert pdc.station.id_station_itinerance == statique.id_station_itinerance + assert pdc.station.id_station_local == statique.id_station_local + assert pdc.station.nom_station == statique.nom_station + assert pdc.station.implantation_station == statique.implantation_station + assert pdc.station.nbre_pdc == statique.nbre_pdc + assert pdc.station.condition_acces == statique.condition_acces + assert pdc.station.horaires == statique.horaires + assert pdc.station.station_deux_roues == statique.station_deux_roues + assert pdc.station.raccordement == statique.raccordement + assert pdc.station.num_pdl == statique.num_pdl + assert pdc.station.date_maj == statique.date_maj + assert pdc.station.date_mise_en_service == statique.date_mise_en_service + assert pdc.station.amenageur.nom_amenageur == statique.nom_amenageur + assert pdc.station.amenageur.siren_amenageur == statique.siren_amenageur + assert pdc.station.amenageur.contact_amenageur == statique.contact_amenageur + assert pdc.station.operateur.nom_operateur == statique.nom_operateur + assert pdc.station.operateur.contact_operateur == statique.contact_operateur + assert pdc.station.operateur.telephone_operateur == statique.telephone_operateur + assert pdc.station.enseigne.nom_enseigne == statique.nom_enseigne + assert 
pdc.station.localisation.adresse_station == statique.adresse_station
+        assert (
+            pdc.station.localisation.code_insee_commune == statique.code_insee_commune
+        )
+        assert (
+            pdc.station.localisation._wkb_to_coordinates(
+                pdc.station.localisation.coordonneesXY
+            )
+            == statique.coordonneesXY
+        )
+        assert pdc.station.operational_unit.code == statique.id_station_itinerance[:5]
diff --git a/src/api/tests/test_cli.py b/src/api/tests/test_cli.py
index 0589031c..379ffb1c 100644
--- a/src/api/tests/test_cli.py
+++ b/src/api/tests/test_cli.py
@@ -1,12 +1,24 @@
 """Tests for QualiCharge CLI."""
 
+from io import StringIO
+
+import pandas as pd
 from sqlalchemy import func
 from sqlmodel import select
 
 from qualicharge.auth.factories import GroupFactory, UserFactory
 from qualicharge.auth.schemas import Group, GroupOperationalUnit, User, UserGroup
 from qualicharge.cli import app
-from qualicharge.schemas.core import OperationalUnit
+from qualicharge.factories.static import StatiqueFactory
+from qualicharge.schemas.core import (
+    Amenageur,
+    Enseigne,
+    Localisation,
+    Operateur,
+    OperationalUnit,
+    PointDeCharge,
+    Station,
+)
 
 
 def test_list_groups(runner, db_session):
@@ -392,3 +404,75 @@ def test_delete_user(runner, db_session):
     # Check that user no longer exists
     assert db_session.exec(select(func.count(User.id))).one() == 0
     assert db_session.exec(select(func.count(UserGroup.group_id))).one() == 0
+
+
+def test_import_static(runner, db_session):
+    """Test the `import-static` command."""
+    # Create statique data to import
+    size = 5
+    statiques = StatiqueFactory.batch(size=size)
+    df = pd.read_json(
+        StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"),
+        lines=True,
+        dtype_backend="pyarrow",
+    )
+
+    # No database records exist yet
+    assert db_session.exec(select(func.count(Amenageur.id))).one() == 0
+    assert db_session.exec(select(func.count(Enseigne.id))).one() == 0
+    assert db_session.exec(select(func.count(Localisation.id))).one() == 0
+    assert db_session.exec(select(func.count(Operateur.id))).one() == 0
+    assert db_session.exec(select(func.count(PointDeCharge.id))).one() == 0
+    assert db_session.exec(select(func.count(Station.id))).one() == 0
+
+    # Write parquet file to import
+    file_path = "test.parquet"
+    with runner.isolated_filesystem():
+        df.to_parquet(file_path)
+        result = runner.invoke(app, ["import-static", file_path], obj=db_session)
+    assert result.exit_code == 0
+
+    # Assert we've created expected records
+    assert db_session.exec(select(func.count(Amenageur.id))).one() == size
+    assert db_session.exec(select(func.count(Enseigne.id))).one() == size
+    assert db_session.exec(select(func.count(Localisation.id))).one() == size
+    assert db_session.exec(select(func.count(Operateur.id))).one() == size
+    assert db_session.exec(select(func.count(PointDeCharge.id))).one() == size
+    assert db_session.exec(select(func.count(Station.id))).one() == size
+
+
+def test_import_static_with_integrity_exception(runner, db_session):
+    """Test the `import-static` command with conflicting charge point IDs."""
+    # Create statique data to import (duplicated IDs make the upsert fail)
+    statiques = StatiqueFactory.batch(size=5)
+    statiques[1].id_pdc_itinerance = "FRS63E0001"
+    statiques[3].id_pdc_itinerance = "FRS63E0001"
+    df = pd.read_json(
+        StringIO(f"{'\n'.join([s.model_dump_json() for s in statiques])}"),
+        lines=True,
+        dtype_backend="pyarrow",
+    )
+
+    # No database records exist yet
+    assert db_session.exec(select(func.count(Amenageur.id))).one() == 0
+    assert db_session.exec(select(func.count(Enseigne.id))).one() == 0
+    assert db_session.exec(select(func.count(Localisation.id))).one() == 0
+    assert db_session.exec(select(func.count(Operateur.id))).one() == 0
+    assert db_session.exec(select(func.count(PointDeCharge.id))).one() == 0
+    assert db_session.exec(select(func.count(Station.id))).one() == 0
+
+    # Write parquet file to import
+    file_path = "test.parquet"
+    with runner.isolated_filesystem():
+        df.to_parquet(file_path)
+        result = runner.invoke(app, ["import-static", file_path], obj=db_session)
+    assert result.exit_code == 1
+    assert "Input file importation failed. Rolling back." in str(result.exception)
+
+    # Assert we've not created any record
+    assert db_session.exec(select(func.count(Amenageur.id))).one() == 0
+    assert db_session.exec(select(func.count(Enseigne.id))).one() == 0
+    assert db_session.exec(select(func.count(Localisation.id))).one() == 0
+    assert db_session.exec(select(func.count(Operateur.id))).one() == 0
+    assert db_session.exec(select(func.count(PointDeCharge.id))).one() == 0
+    assert db_session.exec(select(func.count(Station.id))).one() == 0
diff --git a/src/notebook/misc/import-static.md b/src/notebook/misc/import-static.md
index 7dd956f4..aa5467e2 100644
--- a/src/notebook/misc/import-static.md
+++ b/src/notebook/misc/import-static.md
@@ -151,7 +151,7 @@ localisation = add_timestamped_table_fields(localisation)
 
 # Convert to a GeoDataFrame
 localisation = gp.GeoDataFrame(localisation, crs="EPSG:4326", geometry="coordonneesXY")
-localisation
+localisation[localisation["code_insee_commune"] == "77018"]
 ```
 
 ```python
@@ -377,3 +377,87 @@ save(pdc, engine, "pointdecharge", truncate=True, dtype=dtype)
 saved = pd.read_sql("SELECT * FROM PointDeCharge", engine)
 saved
 ```
+
+## Alternate version using raw SQLAlchemy
+
+```python
+amenageur_fields = ["nom_amenageur", "siren_amenageur", "contact_amenageur"]
+amenageur = static[amenageur_fields]
+
+# Remove duplicates
+amenageur = amenageur.drop_duplicates()
+
+# Add missing columns (to fit with the ORM)
+amenageur = add_timestamped_table_fields(amenageur)
+amenageur
+```
+
+```python
+%%time
+from sqlalchemy import Table
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.schema import MetaData
+
+def save_amenageur(df):
+    metadata_obj = MetaData()
+    t_amenageur = Table("amenageur", metadata_obj, autoload_with=engine)
+
+    # Use the `df` argument (not the `amenageur` global) so the helper is reusable
+    df.drop("amenageur_id", axis=1, inplace=True, errors="ignore")
+
+    stmt = insert(t_amenageur).values(df.to_dict("records"))
+    stmt = stmt.on_conflict_do_update(
+        constraint="amenageur_nom_amenageur_siren_amenageur_contact_amenageur_key",
+        set_=dict(
+            nom_amenageur=stmt.excluded.nom_amenageur,
+            siren_amenageur=stmt.excluded.siren_amenageur,
+            contact_amenageur=stmt.excluded.contact_amenageur,
+            updated_at=stmt.excluded.updated_at
+        )
+    )
+    stmt = stmt.returning(t_amenageur.c.id)
+
+    with engine.connect() as conn:
+        result = conn.execute(stmt)
+        ids = pd.Series(data=(row.id for row in result.all()), index=df.index)
+        conn.commit()  # connections roll back on close with SQLAlchemy 2.x
+
+    df.insert(0, "amenageur_id", ids)
+    return df
+
+amenageur = save_amenageur(amenageur)
+amenageur
+```
+
+```python
+%%time
+
+def save_amenageur_by_chunks(df, n=10000):
+    metadata_obj = MetaData()
+    t_amenageur = Table("amenageur", metadata_obj, autoload_with=engine)
+
+    df.drop("amenageur_id", axis=1, inplace=True, errors="ignore")
+
+    chunks = [df[i:i+n] for i in range(0,len(df),n)]
+    for chunk in chunks:
+        stmt = insert(t_amenageur).values(chunk.to_dict("records"))
+        stmt = stmt.on_conflict_do_update(
+            
constraint="amenageur_nom_amenageur_siren_amenageur_contact_amenageur_key",
+            set_=dict(
+                nom_amenageur=stmt.excluded.nom_amenageur,
+                siren_amenageur=stmt.excluded.siren_amenageur,
+                contact_amenageur=stmt.excluded.contact_amenageur,
+                updated_at=stmt.excluded.updated_at
+            )
+        )
+        stmt = stmt.returning(t_amenageur.c.id)
+
+        with engine.connect() as conn:
+            result = conn.execute(stmt)
+            ids = pd.Series(data=(row.id for row in result.all()), index=chunk.index)
+            conn.commit()  # connections roll back on close with SQLAlchemy 2.x
+
+        # Chunks are slices (copies), so write the returned ids back to the
+        # source dataframe instead of inserting them into each chunk
+        df.loc[chunk.index, "amenageur_id"] = ids
+    return df
+
+amenageur = save_amenageur_by_chunks(amenageur, n=2000)
+```
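+
+## Using the packaged `StatiqueImporter`
+
+The chunked upsert above is the pattern that the packaged
+`qualicharge.schemas.sql.StatiqueImporter` generalizes to every Statique
+table. A minimal usage sketch, assuming the `static` dataframe and `engine`
+defined earlier in this notebook:
+
+```python
+from qualicharge.schemas.sql import StatiqueImporter
+
+with engine.connect() as connection:
+    importer = StatiqueImporter(static, connection)
+    # Upsert amenageur, operateur, enseigne, localisation, station and
+    # pointdecharge, in dependency order
+    importer.save()
+    connection.commit()
+```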
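+
+To exercise the `import-static` CLI command, a parquet input file can be
+produced the same way the test suite does; a sketch reusing the API test
+factories (the `statiques.parquet` file name is arbitrary):
+
+```python
+from io import StringIO
+
+from qualicharge.factories.static import StatiqueFactory
+
+# Serialize factory-generated Statique models to JSON lines, then load them
+# back as a dataframe
+statiques = StatiqueFactory.batch(size=100)
+df = pd.read_json(
+    StringIO("\n".join(s.model_dump_json() for s in statiques)),
+    lines=True,
+    dtype_backend="pyarrow",
+)
+df.to_parquet("statiques.parquet")
+```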
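+
+Finally, a quick sanity-check sketch on what was saved, mirroring the
+`pd.read_sql` calls used earlier in this notebook:
+
+```python
+# Count upserted rows per table
+tables = ("amenageur", "operateur", "enseigne", "localisation", "station", "pointdecharge")
+for table in tables:
+    count = pd.read_sql(f"SELECT COUNT(*) AS n FROM {table}", engine)["n"][0]
+    print(f"{table}: {count}")
+```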