diff --git a/arpav_ppcv/database.py b/arpav_ppcv/database.py index 21ec5dca..7e92f7c4 100644 --- a/arpav_ppcv/database.py +++ b/arpav_ppcv/database.py @@ -208,13 +208,9 @@ def update_station( station_update: observations.StationUpdate, ) -> observations.Station: """Update a station.""" - geom = from_shape( - shapely.io.from_geojson(station_update.geom.model_dump_json())) + geom = from_shape(shapely.io.from_geojson(station_update.geom.model_dump_json())) other_data = station_update.model_dump(exclude={"geom"}, exclude_unset=True) - data = { - **other_data, - "geom": geom - } + data = {**other_data, "geom": geom} for key, value in data.items(): setattr(db_station, key, value) session.add(db_station) diff --git a/arpav_ppcv/webapp/admin/app.py b/arpav_ppcv/webapp/admin/app.py index 58121ddc..f627a982 100644 --- a/arpav_ppcv/webapp/admin/app.py +++ b/arpav_ppcv/webapp/admin/app.py @@ -58,8 +58,12 @@ def create_admin(settings: config.ArpavPpcvSettings) -> ArpavPpcvAdmin: Middleware(SqlModelDbSessionMiddleware, engine=engine), ], ) - admin.add_view(coverage_views.ConfigurationParameterView(coverages.ConfigurationParameter)) - admin.add_view(coverage_views.CoverageConfigurationView(coverages.CoverageConfiguration)) + admin.add_view( + coverage_views.ConfigurationParameterView(coverages.ConfigurationParameter) + ) + admin.add_view( + coverage_views.CoverageConfigurationView(coverages.CoverageConfiguration) + ) admin.add_view(observations_views.VariableView(observations.Variable)) admin.add_view(observations_views.StationView(observations.Station)) admin.add_view( @@ -67,10 +71,16 @@ def create_admin(settings: config.ArpavPpcvSettings) -> ArpavPpcvAdmin: "Measurements", icon="fa-solid fa-vials", views=[ - observations_views.MonthlyMeasurementView(observations.MonthlyMeasurement), - observations_views.SeasonalMeasurementView(observations.SeasonalMeasurement), - observations_views.YearlyMeasurementView(observations.YearlyMeasurement), - ] + observations_views.MonthlyMeasurementView( + observations.MonthlyMeasurement + ), + observations_views.SeasonalMeasurementView( + observations.SeasonalMeasurement + ), + observations_views.YearlyMeasurementView( + observations.YearlyMeasurement + ), + ], ) ) admin.add_view( diff --git a/arpav_ppcv/webapp/admin/views/observations.py b/arpav_ppcv/webapp/admin/views/observations.py index 7f9190c6..572765b7 100644 --- a/arpav_ppcv/webapp/admin/views/observations.py +++ b/arpav_ppcv/webapp/admin/views/observations.py @@ -55,7 +55,8 @@ def can_view_details(self, request: Request) -> bool: @staticmethod def _serialize_instance( - instance: observations.MonthlyMeasurement) -> read_schemas.MonthlyMeasurementRead: + instance: observations.MonthlyMeasurement, + ) -> read_schemas.MonthlyMeasurementRead: return read_schemas.MonthlyMeasurementRead( **instance.model_dump(), station=instance.station.code, @@ -63,12 +64,12 @@ def _serialize_instance( ) async def find_all( - self, - request: Request, - skip: int = 0, - limit: int = 100, - where: Union[dict[str, Any], str, None] = None, - order_by: Optional[list[str]] = None, + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, ) -> Sequence[read_schemas.MonthlyMeasurementRead]: list_measurements = functools.partial( db.list_monthly_measurements, @@ -77,7 +78,8 @@ async def find_all( include_total=False, ) db_measurements, _ = await anyio.to_thread.run_sync( - list_measurements, request.state.session) + list_measurements, request.state.session + ) return [self._serialize_instance(item) for item in db_measurements] @@ -111,7 +113,8 @@ def can_view_details(self, request: Request) -> bool: @staticmethod def _serialize_instance( - instance: observations.SeasonalMeasurement) -> read_schemas.SeasonalMeasurementRead: + instance: observations.SeasonalMeasurement, + ) -> read_schemas.SeasonalMeasurementRead: return read_schemas.SeasonalMeasurementRead( **instance.model_dump(), station=instance.station.code, @@ -119,12 +122,12 @@ def _serialize_instance( ) async def find_all( - self, - request: Request, - skip: int = 0, - limit: int = 100, - where: Union[dict[str, Any], str, None] = None, - order_by: Optional[list[str]] = None, + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, ) -> Sequence[read_schemas.SeasonalMeasurementRead]: list_measurements = functools.partial( db.list_seasonal_measurements, @@ -133,7 +136,8 @@ async def find_all( include_total=False, ) db_measurements, _ = await anyio.to_thread.run_sync( - list_measurements, request.state.session) + list_measurements, request.state.session + ) return [self._serialize_instance(item) for item in db_measurements] @@ -166,7 +170,8 @@ def can_view_details(self, request: Request) -> bool: @staticmethod def _serialize_instance( - instance: observations.YearlyMeasurement) -> read_schemas.YearlyMeasurementRead: + instance: observations.YearlyMeasurement, + ) -> read_schemas.YearlyMeasurementRead: return read_schemas.YearlyMeasurementRead( **instance.model_dump(), station=instance.station.code, @@ -174,12 +179,12 @@ def _serialize_instance( ) async def find_all( - self, - request: Request, - skip: int = 0, - limit: int = 100, - where: Union[dict[str, Any], str, None] = None, - order_by: Optional[list[str]] = None, + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, ) -> Sequence[read_schemas.YearlyMeasurementRead]: list_measurements = functools.partial( db.list_yearly_measurements, @@ -188,7 +193,8 @@ async def find_all( include_total=False, ) db_measurements, _ = await anyio.to_thread.run_sync( - list_measurements, request.state.session) + list_measurements, request.state.session + ) return [self._serialize_instance(item) for item in db_measurements] @@ -215,7 +221,8 @@ def __init__(self, *args, **kwargs) -> None: @staticmethod def _serialize_instance( - instance: observations.Variable) -> read_schemas.VariableRead: + instance: observations.Variable, + ) -> read_schemas.VariableRead: return read_schemas.VariableRead(**instance.model_dump()) async def get_pk_value(self, request: Request, obj: Any) -> str: @@ -227,7 +234,7 @@ async def get_pk_value(self, request: Request, obj: Any) -> str: return str(result) async def create( - self, request: Request, data: dict[str, Any] + self, request: Request, data: dict[str, Any] ) -> Optional[read_schemas.VariableRead]: try: data = await self._arrange_data(request, data) @@ -243,7 +250,7 @@ async def create( return self.handle_exception(e) async def edit( - self, request: Request, pk: Any, data: dict[str, Any] + self, request: Request, pk: Any, data: dict[str, Any] ) -> Optional[read_schemas.VariableRead]: try: data = await self._arrange_data(request, data, True) @@ -253,26 +260,26 @@ async def edit( db.get_variable, request.state.session, pk ) db_var = await anyio.to_thread.run_sync( - db.update_variable, request.state.session, db_var, var_update) + db.update_variable, request.state.session, db_var, var_update + ) return self._serialize_instance(db_var) except Exception as e: logger.exception("something went wrong") self.handle_exception(e) - async def find_by_pk( - self, request: Request, pk: Any - ) -> read_schemas.VariableRead: + async def find_by_pk(self, request: Request, pk: Any) -> read_schemas.VariableRead: db_var = await anyio.to_thread.run_sync( - db.get_variable, request.state.session, pk) + db.get_variable, request.state.session, pk + ) return self._serialize_instance(db_var) async def find_all( - self, - request: Request, - skip: int = 0, - limit: int = 100, - where: Union[dict[str, Any], str, None] = None, - order_by: Optional[list[str]] = None, + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, ) -> Sequence[read_schemas.VariableRead]: list_variables = functools.partial( db.list_variables, @@ -281,7 +288,8 @@ async def find_all( include_total=False, ) db_vars, _ = await anyio.to_thread.run_sync( - list_variables, request.state.session) + list_variables, request.state.session + ) return [self._serialize_instance(db_var) for db_var in db_vars] @@ -312,16 +320,10 @@ def __init__(self, *args, **kwargs) -> None: self.icon = "fa-solid fa-tower-observation" @staticmethod - def _serialize_instance( - instance: observations.Station) -> read_schemas.StationRead: + def _serialize_instance(instance: observations.Station) -> read_schemas.StationRead: geom = shapely.io.from_wkb(bytes(instance.geom.data)) return read_schemas.StationRead( - **instance.model_dump( - exclude={ - "geom", - "type_" - } - ), + **instance.model_dump(exclude={"geom", "type_"}), type=instance.type_, longitude=geom.x, latitude=geom.y, @@ -352,21 +354,26 @@ async def validate(self, request: Request, data: dict[str, Any]) -> None: fields_to_exclude = [ f.name for f in self.get_fields_list(request, request.state.action) - if isinstance(f, (starlette_admin.FileField, starlette_admin.RelationField)) + if isinstance( + f, (starlette_admin.FileField, starlette_admin.RelationField) + ) ] + ["latitude", "longitude"] self.model.validate( - {k: v for k, v in data_to_validate.items() if k not in fields_to_exclude} + { + k: v + for k, v in data_to_validate.items() + if k not in fields_to_exclude + } ) async def create( - self, request: Request, data: dict[str, Any] + self, request: Request, data: dict[str, Any] ) -> Optional[read_schemas.StationRead]: try: data = await self._arrange_data(request, data) await self.validate(request, data) geojson_geom = geojson_pydantic.Point( - type="Point", - coordinates=(data.pop("longitude"), data.pop("latitude")) + type="Point", coordinates=(data.pop("longitude"), data.pop("latitude")) ) station_create = observations.StationCreate( type_=data.pop("type"), @@ -384,7 +391,7 @@ async def create( return self.handle_exception(e) async def edit( - self, request: Request, pk: Any, data: dict[str, Any] + self, request: Request, pk: Any, data: dict[str, Any] ) -> Optional[read_schemas.StationRead]: try: data = await self._arrange_data(request, data, True) @@ -394,39 +401,35 @@ async def edit( kwargs = {} if all((lon, lat)): kwargs["geom"] = geojson_pydantic.Point( - type="Point", - coordinates=(lon, lat) + type="Point", coordinates=(lon, lat) ) if (type_ := data.pop("type", None)) is not None: kwargs["type_"] = type_ - station_update = observations.StationUpdate( - **data, - **kwargs - ) + station_update = observations.StationUpdate(**data, **kwargs) db_station = await anyio.to_thread.run_sync( db.get_station, request.state.session, pk ) db_station = await anyio.to_thread.run_sync( - db.update_station, request.state.session, db_station, station_update) + db.update_station, request.state.session, db_station, station_update + ) return self._serialize_instance(db_station) except Exception as e: logger.exception("something went wrong") self.handle_exception(e) - async def find_by_pk( - self, request: Request, pk: Any - ) -> read_schemas.StationRead: + async def find_by_pk(self, request: Request, pk: Any) -> read_schemas.StationRead: db_station = await anyio.to_thread.run_sync( - db.get_station, request.state.session, pk) + db.get_station, request.state.session, pk + ) return self._serialize_instance(db_station) async def find_all( - self, - request: Request, - skip: int = 0, - limit: int = 100, - where: Union[dict[str, Any], str, None] = None, - order_by: Optional[list[str]] = None, + self, + request: Request, + skip: int = 0, + limit: int = 100, + where: Union[dict[str, Any], str, None] = None, + order_by: Optional[list[str]] = None, ) -> Sequence[read_schemas.StationRead]: list_stations = functools.partial( db.list_stations, @@ -435,5 +438,6 @@ async def find_all( include_total=False, ) db_stations, _ = await anyio.to_thread.run_sync( - list_stations, request.state.session) + list_stations, request.state.session + ) return [self._serialize_instance(db_station) for db_station in db_stations]