From cca2993cfdd24635d7048bb79e3921f2ed152e43 Mon Sep 17 00:00:00 2001 From: Julien Maupetit Date: Mon, 26 Aug 2024 19:03:42 +0200 Subject: [PATCH] Switch importer to raw SQL --- src/api/qualicharge/cli.py | 4 +- src/api/qualicharge/schemas/sql.py | 161 +++++++++++------------------ 2 files changed, 60 insertions(+), 105 deletions(-) diff --git a/src/api/qualicharge/cli.py b/src/api/qualicharge/cli.py index 0f8e0177..58245ce2 100644 --- a/src/api/qualicharge/cli.py +++ b/src/api/qualicharge/cli.py @@ -432,8 +432,8 @@ def import_static(ctx: typer.Context, input_file: Path): with console.status("Saving Statiques to database…"): saved = static.save() - updated = static.update() - console.log(f"Saved/updated {saved}/{updated} entries.") + console.log(f"Saved {saved} entries.") + session.commit() @app.callback() diff --git a/src/api/qualicharge/schemas/sql.py b/src/api/qualicharge/schemas/sql.py index 5c49c04a..9f4d8a2b 100644 --- a/src/api/qualicharge/schemas/sql.py +++ b/src/api/qualicharge/schemas/sql.py @@ -5,14 +5,17 @@ import json import logging -from typing_extensions import Optional import uuid -from threading import local +from typing_extensions import Optional import geopandas as gp import pandas as pd from shapely.geometry import Point +from sqlalchemy import Table +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session +from sqlalchemy.schema import MetaData from ..models.static import Statique from . import BaseTimestampedSQLModel @@ -22,9 +25,9 @@ class StatiqueImporter: - """Statique model data frame.""" + """Statique importer from a Pandas Dataframe.""" - def __init__(self, df: pd.DataFrame, connection: Connection, flag_new: bool = True): + def __init__(self, df: pd.DataFrame, connection: Connection): """Add table cache keys.""" self._statique: pd.DataFrame = df @@ -36,48 +39,11 @@ def __init__(self, df: pd.DataFrame, connection: Connection, flag_new: bool = Tr self._station: pd.DataFrame = None self.connection: Connection = connection - if flag_new: - self._flag_new() def __len__(self): """Object length corresponds to the static dataframe length.""" return len(self._statique.index) - def _flag_new(self): - """Add existing PDC or Station (and related).""" - station_ids = self._statique["id_station_itinerance"].unique() - stations_exists_query = ( - "SELECT " - "id as station_id, " - "amenageur_id, " - "operateur_id, " - "enseigne_id, " - "localisation_id, " - "operational_unit_id, " - "id_station_itinerance " - "FROM station " - "WHERE id_station_itinerance IN " - f"('{"','".join(station_ids)}')" - ) - existing_stations = pd.read_sql(stations_exists_query, self.connection) - self._statique = self._statique.merge( - existing_stations, how="left", on="id_station_itinerance" - ) - - pdc_ids = self._statique["id_pdc_itinerance"].unique() - pdc_exists_query = ( - "SELECT id as pointdecharge_id, id_pdc_itinerance " - "FROM pointdecharge " - f"WHERE id_pdc_itinerance IN ('{"','".join(pdc_ids)}')" - ) - existing_pdcs = pd.read_sql(pdc_exists_query, self.connection) - self._statique = self._statique.merge( - existing_pdcs, how="left", on="id_pdc_itinerance" - ) - self._statique["is_new"] = False - self._statique.loc[self._statique["station_id"].isnull(), "is_new"] = True - self._statique.loc[self._statique["pointdecharge_id"].isnull(), "is_new"] = True - @staticmethod def _get_fields_for_schema(schema: BaseTimestampedSQLModel): """Get Statique fields from a core schema.""" @@ -86,18 +52,19 @@ def _get_fields_for_schema(schema: BaseTimestampedSQLModel): @staticmethod def _add_timestamped_model_fields(df: pd.DataFrame): """Add required fields for a BaseTimestampedSQLModel.""" - df["id"] = None + df["id"] = df.apply(lambda x: uuid.uuid4(), axis=1) now = pd.Timestamp.now(tz="utc") df["created_at"] = now df["updated_at"] = now return df - @staticmethod def _get_dataframe_for_schema( - df: pd.DataFrame, schema: BaseTimestampedSQLModel, subset: Optional[str] = None + self, + schema: BaseTimestampedSQLModel, + subset: Optional[str] = None, ): """Extract Schema DataFrame from original Statique DataFrame.""" - df = df[_get_fields_for_schema(schema)] + df = self._statique[self._get_fields_for_schema(schema)] df = df.drop_duplicates(subset) df = self._add_timestamped_model_fields(df) return df @@ -152,64 +119,52 @@ def station(self): self._station = self._get_dataframe_for_schema(Station) return self._station + def _save_schema( + self, + df: pd.DataFrame, + schema: BaseTimestampedSQLModel, + constraint: str, + ): + """Save given dataframe records to the corresponding schema.""" + schema_table = Table( + schema.__table__.name, MetaData(), autoload_with=self.connection + ) + + stmt = insert(schema_table).values(df.to_dict("records")) + updates_on_conflict = { + f: stmt.excluded.get(f) for f in self._get_fields_for_schema(schema) + } + updates_on_conflict.update({"updated_at": stmt.excluded.updated_at}) + stmt = stmt.on_conflict_do_update( + constraint=constraint, + set_=updates_on_conflict, + ) + stmt = stmt.returning(schema_table.c.id) + + # FIXME + # with Session(self.connection) as session: + # session.begin() + # try: + # result = session.execute(stmt) + # except: + # session.rollback() + # raise + # else: + # session.commit() + result = self.connection.execute(stmt) + + df.insert( + 0, + f"{schema.__table__.name}_id", + pd.Series(data=(row.id for row in result.all()), index=df.index), + ) + return df + def save(self) -> int: """Save new entries.""" - to_save = StatiqueImporter( - self._statique.loc[self._statique["is_new"]], - self.connection, - flag_new=False, - ) - if not len(to_save): - return 0 - - localisation = to_save.localisation.loc[ - to_save._statique["station_id"].isnull() - ] - localisation["id"] = localisation.apply(lambda _: uuid.uuid4(), axis=1) - # localisation.to_postgis( - # "localisation", to_save.connection, if_exists="append", index=False - # ) - - return len(to_save) - - def update(self) -> int: - """Update existing entries.""" - to_update = StatiqueImporter( - self._statique.loc[~self._statique["is_new"]], - self.connection, - flag_new=False, + amenageur = self._save_schema( + self.amenageur, + Amenageur, + constraint="amenageur_nom_amenageur_siren_amenageur_contact_amenageur_key", ) - return len(to_update) - - -def importer(statique: pd.DataFrame, connection: Connection): - """FIXME.""" - station_ids = statique["id_station_itinerance"].unique() - stations_exists_query = ( - "SELECT " - "id as station_id, " - "amenageur_id, " - "operateur_id, " - "enseigne_id, " - "localisation_id, " - "operational_unit_id, " - "id_station_itinerance " - "FROM station " - "WHERE id_station_itinerance IN " - f"('{"','".join(station_ids)}')" - ) - existing_stations = pd.read_sql(stations_exists_query, connection) - statique = statique.merge(existing_stations, how="left", on="id_station_itinerance") - - pdc_ids = statique["id_pdc_itinerance"].unique() - pdc_exists_query = ( - "SELECT id as pointdecharge_id, id_pdc_itinerance " - "FROM pointdecharge " - f"WHERE id_pdc_itinerance IN ('{"','".join(pdc_ids)}')" - ) - existing_pdcs = pd.read_sql(pdc_exists_query, connection) - statique = statique.merge(existing_pdcs, how="left", on="id_pdc_itinerance") - - to_create = statique.loc[ - statique["station_id"].isnull() | statique["pointdecharge_id"].isnull() - ] + return 0