Skip to content

Commit

Permalink
Dumb work
Browse files Browse the repository at this point in the history
  • Loading branch information
jmaupetit committed Aug 1, 2024
1 parent eb863c7 commit 4cb2970
Showing 1 changed file with 80 additions and 16 deletions.
96 changes: 80 additions & 16 deletions src/api/qualicharge/schemas/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import json
import logging
from typing_extensions import Optional
import uuid
from threading import local

import geopandas as gp
import pandas as pd
Expand All @@ -21,7 +24,7 @@
class StatiqueImporter:
"""Statique model data frame."""

def __init__(self, df: pd.DataFrame, connection: Connection):
def __init__(self, df: pd.DataFrame, connection: Connection, flag_new: bool = True):
"""Add table cache keys."""
self._statique: pd.DataFrame = df

Expand All @@ -33,9 +36,14 @@ def __init__(self, df: pd.DataFrame, connection: Connection):
self._station: pd.DataFrame = None

self.connection: Connection = connection
self.flag_new()
if flag_new:
self._flag_new()

def flag_new(self):
def __len__(self):
"""Object length corresponds to the static dataframe length."""
return len(self._statique.index)

def _flag_new(self):
"""Add existing PDC or Station (and related)."""
station_ids = self._statique["id_station_itinerance"].unique()
stations_exists_query = (
Expand Down Expand Up @@ -66,27 +74,31 @@ def flag_new(self):
self._statique = self._statique.merge(
existing_pdcs, how="left", on="id_pdc_itinerance"
)
self._statique["new"] = self._statique.apply(
lambda x: False if all((x["station_id"], x["pointdecharge_id"])) else True,
axis=1,
)
self._statique["is_new"] = False
self._statique.loc[self._statique["station_id"].isnull(), "is_new"] = True
self._statique.loc[self._statique["pointdecharge_id"].isnull(), "is_new"] = True

def _get_fields_for_schema(self, schema: BaseTimestampedSQLModel):
@staticmethod
def _get_fields_for_schema(schema: BaseTimestampedSQLModel):
"""Get Statique fields from a core schema."""
return list(set(Statique.model_fields.keys()) & set(schema.model_fields.keys()))

def _add_timestamped_model_fields(self, df: pd.DataFrame):
@staticmethod
def _add_timestamped_model_fields(df: pd.DataFrame):
"""Add required fields for a BaseTimestampedSQLModel."""
df["id"] = None
now = pd.Timestamp.now(tz="utc")
df["created_at"] = now
df["updated_at"] = now
return df

def _get_dataframe_for_schema(self, schema: BaseTimestampedSQLModel):
@staticmethod
def _get_dataframe_for_schema(
df: pd.DataFrame, schema: BaseTimestampedSQLModel, subset: Optional[str] = None
):
"""Extract Schema DataFrame from original Statique DataFrame."""
df = self._statique[self._get_fields_for_schema(schema)]
df = df.drop_duplicates()
df = df[_get_fields_for_schema(schema)]
df = df.drop_duplicates(subset)
df = self._add_timestamped_model_fields(df)
return df

Expand Down Expand Up @@ -142,10 +154,62 @@ def station(self):

def save(self) -> int:
"""Save new entries."""
to_save = self._statique.loc[self._statique["new"]]
return len(to_save.index)
to_save = StatiqueImporter(
self._statique.loc[self._statique["is_new"]],
self.connection,
flag_new=False,
)
if not len(to_save):
return 0

localisation = to_save.localisation.loc[
to_save._statique["station_id"].isnull()
]
localisation["id"] = localisation.apply(lambda _: uuid.uuid4(), axis=1)
# localisation.to_postgis(
# "localisation", to_save.connection, if_exists="append", index=False
# )

return len(to_save)

def update(self) -> int:
"""Update existing entries."""
to_update = self._statique.loc[~self._statique["new"]]
return len(to_update.index)
to_update = StatiqueImporter(
self._statique.loc[~self._statique["is_new"]],
self.connection,
flag_new=False,
)
return len(to_update)


def importer(statique: pd.DataFrame, connection: Connection):
"""FIXME."""
station_ids = statique["id_station_itinerance"].unique()
stations_exists_query = (
"SELECT "
"id as station_id, "
"amenageur_id, "
"operateur_id, "
"enseigne_id, "
"localisation_id, "
"operational_unit_id, "
"id_station_itinerance "
"FROM station "
"WHERE id_station_itinerance IN "
f"('{"','".join(station_ids)}')"
)
existing_stations = pd.read_sql(stations_exists_query, connection)
statique = statique.merge(existing_stations, how="left", on="id_station_itinerance")

pdc_ids = statique["id_pdc_itinerance"].unique()
pdc_exists_query = (
"SELECT id as pointdecharge_id, id_pdc_itinerance "
"FROM pointdecharge "
f"WHERE id_pdc_itinerance IN ('{"','".join(pdc_ids)}')"
)
existing_pdcs = pd.read_sql(pdc_exists_query, connection)
statique = statique.merge(existing_pdcs, how="left", on="id_pdc_itinerance")

to_create = statique.loc[
statique["station_id"].isnull() | statique["pointdecharge_id"].isnull()
]

0 comments on commit 4cb2970

Please sign in to comment.